yinuozhang committed
Commit 813c6b1 · 1 Parent(s): 3e730f5

add functions

README.md CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:b72853bda66e29cdc787331b4373ad6575f86092f05e4caa775fd50f7cbcda2e
- size 206
+ oid sha256:8b4a57e9caf84b0991a9a349cb28b44049995f4a51ccc3118a0114baf856f36a
+ size 839
functions/binding.py ADDED
@@ -0,0 +1,186 @@
+ import torch
+ import torch.nn as nn
+ import esm
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+ from transformers import AutoModelForMaskedLM
+
+ base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
+
+ class ImprovedBindingPredictor(nn.Module):
+     def __init__(self,
+                  esm_dim=1280,
+                  smiles_dim=768,
+                  hidden_dim=512,
+                  n_heads=8,
+                  n_layers=3,
+                  dropout=0.1):
+         super().__init__()
+
+         # Define binding thresholds
+         self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
+         self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1μM
+
+         # Project both modalities to the same dimension
+         self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
+         self.protein_projection = nn.Linear(esm_dim, hidden_dim)
+         self.protein_norm = nn.LayerNorm(hidden_dim)
+         self.smiles_norm = nn.LayerNorm(hidden_dim)
+
+         # Cross-attention blocks with layer norm
+         self.cross_attention_layers = nn.ModuleList([
+             nn.ModuleDict({
+                 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
+                 'norm1': nn.LayerNorm(hidden_dim),
+                 'ffn': nn.Sequential(
+                     nn.Linear(hidden_dim, hidden_dim * 4),
+                     nn.ReLU(),
+                     nn.Dropout(dropout),
+                     nn.Linear(hidden_dim * 4, hidden_dim)
+                 ),
+                 'norm2': nn.LayerNorm(hidden_dim)
+             }) for _ in range(n_layers)
+         ])
+
+         # Shared prediction trunk
+         self.shared_head = nn.Sequential(
+             nn.Linear(hidden_dim * 2, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+         )
+
+         # Regression head
+         self.regression_head = nn.Linear(hidden_dim, 1)
+
+         # Classification head (3 classes: tight, medium, weak binding)
+         self.classification_head = nn.Linear(hidden_dim, 3)
+
+     def get_binding_class(self, affinity):
+         """Convert affinity values to class indices:
+         0: tight binding (>= 7.5)
+         1: medium binding (6.0-7.5)
+         2: weak binding (< 6.0)
+         """
+         if isinstance(affinity, torch.Tensor):
+             tight_mask = affinity >= self.tight_threshold
+             weak_mask = affinity < self.weak_threshold
+             medium_mask = ~(tight_mask | weak_mask)
+
+             classes = torch.zeros_like(affinity, dtype=torch.long)
+             classes[medium_mask] = 1
+             classes[weak_mask] = 2
+             return classes
+         else:
+             if affinity >= self.tight_threshold:
+                 return 0  # tight binding
+             elif affinity < self.weak_threshold:
+                 return 2  # weak binding
+             else:
+                 return 1  # medium binding
+
+     def forward(self, protein_emb, smiles_emb):
+         protein = self.protein_norm(self.protein_projection(protein_emb))
+         smiles = self.smiles_norm(self.smiles_projection(smiles_emb))
+
+         # protein = protein.transpose(0, 1)
+         # smiles = smiles.transpose(0, 1)
+
+         # Cross-attention layers
+         for layer in self.cross_attention_layers:
+             # Protein attending to SMILES
+             attended_protein = layer['attention'](
+                 protein, smiles, smiles
+             )[0]
+             protein = layer['norm1'](protein + attended_protein)
+             protein = layer['norm2'](protein + layer['ffn'](protein))
+
+             # SMILES attending to protein
+             attended_smiles = layer['attention'](
+                 smiles, protein, protein
+             )[0]
+             smiles = layer['norm1'](smiles + attended_smiles)
+             smiles = layer['norm2'](smiles + layer['ffn'](smiles))
+
+         # Sequence-level representations
+         protein_pool = torch.mean(protein, dim=0)
+         smiles_pool = torch.mean(smiles, dim=0)
+
+         # Concatenate both representations
+         combined = torch.cat([protein_pool, smiles_pool], dim=-1)
+
+         # Shared features
+         shared_features = self.shared_head(combined)
+
+         regression_output = self.regression_head(shared_features)
+         classification_logits = self.classification_head(shared_features)
+
+         return regression_output, classification_logits
+
+ class BindingAffinity:
+     def __init__(self, prot_seq, model_type='PeptideCLM'):
+         # peptide embeddings
+         self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
+         self.pep_tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/functions/tokenizer/new_vocab.txt',
+                                                   f'{base_path}/functions/tokenizer/new_splits.txt')
+         self.model = ImprovedBindingPredictor()
+         checkpoint = torch.load(f'{base_path}/src/binding/best_model.pt', weights_only=False)
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+
+         self.model.eval()
+
+         self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # load ESM-2 model
+         self.prot_tokenizer = alphabet.get_batch_converter()  # load the ESM tokenizer
+
+         data = [("target", prot_seq)]
+         # tokenize the target protein once and cache its mean-pooled embedding
+         _, _, prot_tokens = self.prot_tokenizer(data)
+         with torch.no_grad():
+             results = self.esm_model.forward(prot_tokens, repr_layers=[33])
+             prot_emb = results["representations"][33]
+
+         self.prot_emb = prot_emb[0]
+         self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True)
+
+     def forward(self, input_seqs):
+         with torch.no_grad():
+             scores = []
+             for seq in input_seqs:
+                 pep_tokens = self.pep_tokenizer(seq, return_tensors='pt', padding=True)
+
+                 emb = self.pep_model(input_ids=pep_tokens['input_ids'],
+                                      attention_mask=pep_tokens['attention_mask'],
+                                      output_hidden_states=True)
+
+                 pep_emb = emb.last_hidden_state.squeeze(0)
+                 pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)
+
+                 score, logits = self.model.forward(self.prot_emb, pep_emb)
+                 scores.append(score.item())
+         return scores
+
+     def __call__(self, input_seqs: list):
+         return self.forward(input_seqs)
+
+ def unittest():
+     amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
+     tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
+     gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
+     glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
+     glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
+     ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
+
+     binding = BindingAffinity(tfr)
+
+     seq = ["CC[C@H](C)[C@H](NC(=O)[C@H](C)NC(=O)[C@@H](N)Cc1c[nH]cn1)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N1CCC[C@H]1C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1c[nH]cn1)C(=O)O"]
+
+     scores = binding(seq)
+     print(scores)
+
+ if __name__ == '__main__':
+     unittest()
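Note: the thresholds above imply the regression head predicts affinity on a -log10(M) scale (pKd/pKi/pIC50). A quick check of that arithmetic, assuming pKd-style labels as the inline comments suggest:

    tight_kd_nm = 10 ** (9 - 7.5)  # ≈ 31.6 nM: pKd >= 7.5 is the "tight" boundary (matches "≤ ~30nM")
    weak_kd_nm = 10 ** (9 - 6.0)   # 1000 nM = 1 µM: pKd < 6.0 is the "weak" boundary (matches "> 1μM")
    print(tight_kd_nm, weak_kd_nm)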
functions/hemolysis.py ADDED
@@ -0,0 +1,69 @@
+ import sys
+ import os
+ import xgboost as xgb
+ import torch
+ import numpy as np
+ from transformers import AutoModelForMaskedLM
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+ import warnings
+ from rdkit.Chem import Descriptors, rdMolDescriptors
+ from rdkit import Chem, rdBase, DataStructs
+ from rdkit.Chem import AllChem
+ from typing import List
+
+ rdBase.DisableLog('rdApp.error')
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ warnings.filterwarnings("ignore", category=UserWarning)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+
+ base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
+
+ class Hemolysis:
+
+     def __init__(self):
+         self.predictor = xgb.Booster(model_file=f'{base_path}/src/best_model_f1.json')
+         self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
+         self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/functions/tokenizer/new_vocab.txt',
+                                               f'{base_path}/functions/tokenizer/new_splits.txt')
+
+     def generate_embeddings(self, sequences):
+         embeddings = []
+         for sequence in sequences:
+             tokenized = self.tokenizer(sequence, return_tensors='pt')
+             with torch.no_grad():
+                 output = self.emb_model(**tokenized)
+             # Mean pooling across sequence length
+             embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
+             embeddings.append(embedding)
+         return np.array(embeddings)
+
+     def get_scores(self, input_seqs: list):
+         scores = np.ones(len(input_seqs))
+         features = self.generate_embeddings(input_seqs)
+
+         if len(features) == 0:
+             return scores
+
+         features = np.nan_to_num(features, nan=0.)
+         features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+
+         features = xgb.DMatrix(features)
+
+         probs = self.predictor.predict(features)
+         # return the probability of being non-hemolytic (1 - P(hemolytic))
+         return scores - probs
+
+     def __call__(self, input_seqs: list):
+         scores = self.get_scores(input_seqs)
+         return scores
+
+ def unittest():
+     hemo = Hemolysis()
+     seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
+
+     scores = hemo(input_seqs=seq)
+     print(scores)
+
+
+ if __name__ == '__main__':
+     unittest()
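Note: the booster here outputs P(hemolytic), so get_scores returns 1 - P(hemolytic); higher means more likely safe. A minimal sketch of thresholding that output (the 0.5 cutoff is illustrative, not from this repo):

    hemo = Hemolysis()
    likely_safe = hemo(input_seqs=smiles_list) > 0.5  # smiles_list: any list of peptide SMILES strings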
functions/nonfouling.py ADDED
@@ -0,0 +1,69 @@
+ import sys
+ import os
+ import xgboost as xgb
+ import torch
+ import numpy as np
+ from transformers import AutoModelForMaskedLM
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+ import warnings
+ from rdkit import Chem, rdBase, DataStructs
+
+
+ rdBase.DisableLog('rdApp.error')
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ warnings.filterwarnings("ignore", category=UserWarning)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+
+ base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
+
+ class Nonfouling:
+
+     def __init__(self):
+         self.predictor = xgb.Booster(model_file=f'{base_path}/src/nonfouling/best_model_f1.json')
+         self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
+         self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/functions/tokenizer/new_vocab.txt',
+                                               f'{base_path}/functions/tokenizer/new_splits.txt')
+
+     def generate_embeddings(self, sequences):
+         embeddings = []
+         for sequence in sequences:
+             tokenized = self.tokenizer(sequence, return_tensors='pt')
+             with torch.no_grad():
+                 output = self.emb_model(**tokenized)
+             # Mean pooling across sequence length
+             embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
+             embeddings.append(embedding)
+         return np.array(embeddings)
+
+     def get_scores(self, input_seqs: list):
+         scores = np.zeros(len(input_seqs))
+         features = self.generate_embeddings(input_seqs)
+
+         if len(features) == 0:
+             return scores
+
+         features = np.nan_to_num(features, nan=0.)
+         features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+
+         features = xgb.DMatrix(features)
+
+         scores = self.predictor.predict(features)
+         # return the probability of being nonfouling
+         return scores
+
+     def __call__(self, input_seqs: list):
+         scores = self.get_scores(input_seqs)
+         return scores
+
+ def unittest():
+     nf = Nonfouling()
+     seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
+
+     scores = nf(input_seqs=seq)
+     print(scores)
+
+
+ if __name__ == '__main__':
+     unittest()
functions/permeability.py ADDED
@@ -0,0 +1,167 @@
+ import sys
+ import os
+ import xgboost as xgb
+ import torch
+ import numpy as np
+ from transformers import AutoModelForMaskedLM
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+ import warnings
+ from rdkit.Chem import Descriptors, rdMolDescriptors
+ from rdkit import Chem, rdBase, DataStructs
+ from rdkit.Chem import AllChem
+ from typing import List
+
+ base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
+
+ rdBase.DisableLog('rdApp.error')
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ warnings.filterwarnings("ignore", category=UserWarning)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+
+ def fingerprints_from_smiles(smiles: List, size=2048):
+     """Create ECFP fingerprints from SMILES, with a validity check."""
+     fps = []
+     valid_mask = []
+     for i, smile in enumerate(smiles):
+         mol = Chem.MolFromSmiles(smile)
+         valid_mask.append(int(mol is not None))
+         fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
+         fps.append(fp)
+
+     fps = np.concatenate(fps, axis=0)
+     return fps, valid_mask
+
+
+ def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
+     """Create the ECFP fingerprint of a molecule."""
+     if hashed:
+         fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
+     else:
+         fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
+     fp_np = np.zeros((1,))
+     DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
+     return fp_np.reshape(1, -1)
+
+ def getMolDescriptors(mol, missingVal=0):
+     """Calculate the full list of RDKit descriptors for a molecule."""
+
+     values, names = [], []
+     for nm, fn in Descriptors._descList:
+         try:
+             val = fn(mol)
+         except Exception:
+             val = missingVal
+         values.append(val)
+         names.append(nm)
+
+     custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
+                           'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
+                           'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}
+
+     for nm, fn in custom_descriptors.items():
+         try:
+             val = fn(mol)
+         except Exception:
+             val = missingVal
+         values.append(val)
+         names.append(nm)
+     return values, names
+
+ def get_pep_dps_from_smi(smi):
+     try:
+         mol = Chem.MolFromSmiles(smi)
+     except Exception:
+         print(f"converting smi {smi} to a molecule failed!")
+         mol = None
+
+     dps, _ = getMolDescriptors(mol)
+     return np.array(dps)
+
+
+ def get_pep_dps(smi_list):
+     if len(smi_list) == 0:
+         return np.zeros((0, 213))
+     return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
+
+ def check_smi_validity(smiles: list):
+     valid_smi, valid_idx = [], []
+     for idx, smi in enumerate(smiles):
+         try:
+             mol = Chem.MolFromSmiles(smi) if smi else None
+             if mol:
+                 valid_smi.append(smi)
+                 valid_idx.append(idx)
+         except Exception as e:
+             # logger.debug(f'Error: {e} in smiles {smi}')
+             pass
+     return valid_smi, valid_idx
+
+ class Permeability:
+
+     def __init__(self):
+         self.predictor = xgb.Booster(model_file=f'{base_path}/src/permeability/best_model.json')
+         self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
+         self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/functions/tokenizer/new_vocab.txt',
+                                               f'{base_path}/functions/tokenizer/new_splits.txt')
+
+     def generate_embeddings(self, sequences):
+         embeddings = []
+         for sequence in sequences:
+             tokenized = self.tokenizer(sequence, return_tensors='pt')
+             with torch.no_grad():
+                 output = self.emb_model(**tokenized)
+             # Mean pooling across sequence length
+             embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
+             embeddings.append(embedding)
+         return np.array(embeddings)
+
+     def get_features(self, input_seqs: list, dps=False, fps=False):
+         # valid_smiles, valid_idxes = check_smi_validity(input_seqs)
+
+         if fps:
+             fingerprints = fingerprints_from_smiles(input_seqs)[0]
+         else:
+             fingerprints = np.empty((len(input_seqs), 0))
+
+         if dps:
+             descriptors = get_pep_dps(input_seqs)
+         else:
+             descriptors = np.empty((len(input_seqs), 0))
+
+         embeddings = self.generate_embeddings(input_seqs)
+         # logger.debug(f'X_fps.shape: {X_fps.shape}, X_dps.shape: {X_dps.shape}')
+
+         features = np.concatenate([fingerprints, descriptors, embeddings], axis=1)
+
+         return features
+
+     def get_scores(self, input_seqs: list):
+         scores = -10 * np.ones(len(input_seqs))
+         features = self.get_features(input_seqs)
+
+         if len(features) == 0:
+             return scores
+
+         features = np.nan_to_num(features, nan=0.)
+         features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+
+         features = xgb.DMatrix(features)
+
+         scores = self.predictor.predict(features)
+         return scores
+
+     def __call__(self, input_seqs: list):
+         scores = self.get_scores(input_seqs)
+         return scores
+
+ def unittest():
+     permeability = Permeability()
+     seq = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
+     scores = permeability(input_seqs=seq)
+     print(scores)
+
+
+ if __name__ == '__main__':
+     unittest()
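Note: get_scores calls get_features with the defaults dps=False, fps=False, so the shipped permeability model is scored on PeptideCLM embeddings alone; the ECFP and descriptor branches are opt-in. If they are enabled, the feature layout must match what best_model.json was trained on. A sketch, assuming such a model exists:

    perm = Permeability()
    # smiles_list is any list of SMILES strings (hypothetical input)
    X = perm.get_features(smiles_list, fps=True, dps=True)  # 2048-bit ECFP + RDKit descriptors + embeddings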
functions/solubility.py ADDED
@@ -0,0 +1,68 @@
+ import sys
+ import os
+ import xgboost as xgb
+ import torch
+ import numpy as np
+ from transformers import AutoModelForMaskedLM
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+ import warnings
+ from rdkit.Chem import Descriptors, rdMolDescriptors
+ from rdkit import Chem, rdBase, DataStructs
+ from rdkit.Chem import AllChem
+ from typing import List
+
+
+ rdBase.DisableLog('rdApp.error')
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ warnings.filterwarnings("ignore", category=UserWarning)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+
+ base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
+
+ class Solubility:
+     def __init__(self):
+         self.predictor = xgb.Booster(model_file=f'{base_path}/src/solubility/best_model_f1.json')
+         self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
+         self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/functions/tokenizer/new_vocab.txt',
+                                               f'{base_path}/functions/tokenizer/new_splits.txt')
+
+     def generate_embeddings(self, sequences):
+         embeddings = []
+         for sequence in sequences:
+             tokenized = self.tokenizer(sequence, return_tensors='pt')
+             with torch.no_grad():
+                 output = self.emb_model(**tokenized)
+             # Mean pooling across sequence length
+             embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
+             embeddings.append(embedding)
+         return np.array(embeddings)
+
+     def get_scores(self, input_seqs: list):
+         scores = np.zeros(len(input_seqs))
+         features = self.generate_embeddings(input_seqs)
+
+         if len(features) == 0:
+             return scores
+
+         features = np.nan_to_num(features, nan=0.)
+         features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+
+         features = xgb.DMatrix(features)
+
+         scores = self.predictor.predict(features)
+         return scores
+
+     def __call__(self, input_seqs: list):
+         scores = self.get_scores(input_seqs)
+         return scores
+
+ def unittest():
+     solubility = Solubility()
+     seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
+     scores = solubility(input_seqs=seq)
+     print(scores)
+
+ if __name__ == '__main__':
+     unittest()
functions/tokenizer/__pycache__/my_tokenizers.cpython-310.pyc ADDED
Binary file (15.5 kB)
functions/tokenizer/my_tokenizers.py ADDED
@@ -0,0 +1,398 @@
+ import collections
+ import logging
+ import os
+ import re
+ from typing import List, Optional
+ from transformers import PreTrainedTokenizer
+ from SmilesPE.tokenizer import SPE_Tokenizer
+
+ logger = logging.getLogger(__name__)
+
+ # Assumed default filename for save_vocabulary(); the original module referenced
+ # VOCAB_FILES_NAMES without defining it.
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+ def load_vocab(vocab_file):
+     """Loads a vocabulary file into a dictionary."""
+     vocab = collections.OrderedDict()
+     with open(vocab_file, "r", encoding="utf-8") as reader:
+         tokens = reader.readlines()
+     for index, token in enumerate(tokens):
+         token = token.rstrip("\n")
+         vocab[token] = index
+     return vocab
+
+ class Atomwise_Tokenizer(object):
+     """Run atom-level SMILES tokenization."""
+
+     def __init__(self):
+         """Constructs an atom-level tokenizer."""
+         # self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
+         self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
+         self.regex = re.compile(self.regex_pattern)
+
+     def tokenize(self, text):
+         """Basic tokenization of a SMILES string."""
+         tokens = [token for token in self.regex.findall(text)]
+         return tokens
+
+ class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
+     r"""
+     Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
+     This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
+     should refer to the superclass for more information regarding methods.
+     Args:
+         vocab_file (:obj:`string`):
+             File containing the vocabulary.
+         spe_file (:obj:`string`):
+             File containing the trained SMILES Pair Encoding vocabulary.
+         unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
+             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+             token instead.
+         sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
+             The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+             for sequence classification or for a text and a question for question answering.
+             It is also used as the last token of a sequence built with special tokens.
+         pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
+             The token used for padding, for example when batching sequences of different lengths.
+         cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
+             The classifier token which is used when doing sequence classification (classification of the whole
+             sequence instead of per-token classification). It is the first token of the sequence when built with
+             special tokens.
+         mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
+             The token used for masking values. This is the token used when training this model with masked language
+             modeling. This is the token which the model will try to predict.
+     """
+
+     def __init__(self, vocab_file, spe_file,
+                  unk_token="[UNK]",
+                  sep_token="[SEP]",
+                  pad_token="[PAD]",
+                  cls_token="[CLS]",
+                  mask_token="[MASK]",
+                  **kwargs):
+         if not os.path.isfile(vocab_file):
+             raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
+         if not os.path.isfile(spe_file):
+             raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))
+
+         self.vocab = load_vocab(vocab_file)
+         self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+         # SPE_Tokenizer consumes the merges file on construction, so close it afterwards
+         with open(spe_file, 'r', encoding='utf-8') as spe_vocab:
+             self.spe_tokenizer = SPE_Tokenizer(spe_vocab)
+
+         super().__init__(
+             unk_token=unk_token,
+             sep_token=sep_token,
+             pad_token=pad_token,
+             cls_token=cls_token,
+             mask_token=mask_token,
+             **kwargs)
+
+     @property
+     def vocab_size(self):
+         return len(self.vocab)
+
+     def get_vocab(self):
+         return dict(self.vocab, **self.added_tokens_encoder)
+
+     def _tokenize(self, text):
+         return self.spe_tokenizer.tokenize(text).split(' ')
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) to an id using the vocab."""
+         return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+     def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+         text = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+         return self.convert_tokens_to_string(text)
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) to a token (str) using the vocab."""
+         return self.ids_to_tokens.get(index, self.unk_token)
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (strings) into a single string."""
+         out_string = " ".join(tokens).replace(" ##", "").strip()
+         return out_string
+
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Build model inputs from a sequence or a pair of sequences for sequence classification tasks
+         by concatenating and adding special tokens.
+         A BERT sequence has the following format:
+         - single sequence: ``[CLS] X [SEP]``
+         - pair of sequences: ``[CLS] A [SEP] B [SEP]``
+         Args:
+             token_ids_0 (:obj:`List[int]`):
+                 List of IDs to which the special tokens will be added
+             token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+                 Optional second list of IDs for sequence pairs.
+         Returns:
+             :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+         """
+         if token_ids_1 is None:
+             return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+         cls = [self.cls_token_id]
+         sep = [self.sep_token_id]
+         return cls + token_ids_0 + sep + token_ids_1 + sep
+
+     def get_special_tokens_mask(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+     ) -> List[int]:
+         """
+         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer ``prepare_for_model`` method.
+         Args:
+             token_ids_0 (:obj:`List[int]`):
+                 List of ids.
+             token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                 Set to True if the token list is already formatted with special tokens for the model
+         Returns:
+             :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+
+         if already_has_special_tokens:
+             if token_ids_1 is not None:
+                 raise ValueError(
+                     "You should not supply a second sequence if the provided sequence of "
+                     "ids is already formatted with special tokens for the model."
+                 )
+             return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+
+         if token_ids_1 is not None:
+             return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+         return [1] + ([0] * len(token_ids_0)) + [1]
+
+     def create_token_type_ids_from_sequences(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
+         A BERT sequence pair mask has the following format:
+         ::
+             0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+             | first sequence     | second sequence |
+         If token_ids_1 is None, only returns the first portion of the mask (0's).
+         Args:
+             token_ids_0 (:obj:`List[int]`):
+                 List of ids.
+             token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+                 Optional second list of IDs for sequence pairs.
+         Returns:
+             :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
+             sequence(s).
+         """
+         sep = [self.sep_token_id]
+         cls = [self.cls_token_id]
+         if token_ids_1 is None:
+             return len(cls + token_ids_0 + sep) * [0]
+         return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+     def save_vocabulary(self, vocab_path):
+         """
+         Save the vocabulary (copy of the original file) and special tokens file to a directory.
+         Args:
+             vocab_path (:obj:`str`):
+                 The directory in which to save the vocabulary.
+         Returns:
+             :obj:`Tuple(str)`: Paths to the files saved.
+         """
+         index = 0
+         if os.path.isdir(vocab_path):
+             vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
+         else:
+             vocab_file = vocab_path
+         with open(vocab_file, "w", encoding="utf-8") as writer:
+             for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                 if index != token_index:
+                     logger.warning(
+                         "Saving vocabulary to {}: vocabulary indices are not consecutive."
+                         " Please check that the vocabulary is not corrupted!".format(vocab_file)
+                     )
+                     index = token_index
+                 writer.write(token + "\n")
+                 index += 1
+         return (vocab_file,)
+
+ class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
+     r"""
+     Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
+     This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
+     should refer to the superclass for more information regarding methods.
+     Args:
+         vocab_file (:obj:`string`):
+             File containing the vocabulary.
+         unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
+             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+             token instead.
+         sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
+             The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+             for sequence classification or for a text and a question for question answering.
+             It is also used as the last token of a sequence built with special tokens.
+         pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
+             The token used for padding, for example when batching sequences of different lengths.
+         cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
+             The classifier token which is used when doing sequence classification (classification of the whole
+             sequence instead of per-token classification). It is the first token of the sequence when built with
+             special tokens.
+         mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
+             The token used for masking values. This is the token used when training this model with masked language
+             modeling. This is the token which the model will try to predict.
+     """
+
+     def __init__(
+         self,
+         vocab_file,
+         unk_token="[UNK]",
+         sep_token="[SEP]",
+         pad_token="[PAD]",
+         cls_token="[CLS]",
+         mask_token="[MASK]",
+         **kwargs
+     ):
+         if not os.path.isfile(vocab_file):
+             raise ValueError(
+                 "Can't find a vocabulary file at path '{}'.".format(vocab_file)
+             )
+         # load the vocab before calling super().__init__, which may consult it
+         self.vocab = load_vocab(vocab_file)
+         self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+         self.tokenizer = Atomwise_Tokenizer()
+
+         super().__init__(
+             unk_token=unk_token,
+             sep_token=sep_token,
+             pad_token=pad_token,
+             cls_token=cls_token,
+             mask_token=mask_token,
+             **kwargs,
+         )
+
+     @property
+     def vocab_size(self):
+         return len(self.vocab)
+
+     def get_vocab(self):
+         return dict(self.vocab, **self.added_tokens_encoder)
+
+     def _tokenize(self, text):
+         return self.tokenizer.tokenize(text)
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) to an id using the vocab."""
+         return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) to a token (str) using the vocab."""
+         return self.ids_to_tokens.get(index, self.unk_token)
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (strings) into a single string."""
+         out_string = " ".join(tokens).replace(" ##", "").strip()
+         return out_string
+
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Build model inputs from a sequence or a pair of sequences for sequence classification tasks
+         by concatenating and adding special tokens.
+         A BERT sequence has the following format:
+         - single sequence: ``[CLS] X [SEP]``
+         - pair of sequences: ``[CLS] A [SEP] B [SEP]``
+         Args:
+             token_ids_0 (:obj:`List[int]`):
+                 List of IDs to which the special tokens will be added
+             token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+                 Optional second list of IDs for sequence pairs.
+         Returns:
+             :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+         """
+         if token_ids_1 is None:
+             return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+         cls = [self.cls_token_id]
+         sep = [self.sep_token_id]
+         return cls + token_ids_0 + sep + token_ids_1 + sep
+
+     def get_special_tokens_mask(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+     ) -> List[int]:
+         """
+         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer ``prepare_for_model`` method.
+         Args:
+             token_ids_0 (:obj:`List[int]`):
+                 List of ids.
+             token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                 Set to True if the token list is already formatted with special tokens for the model
+         Returns:
+             :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+
+         if already_has_special_tokens:
+             if token_ids_1 is not None:
+                 raise ValueError(
+                     "You should not supply a second sequence if the provided sequence of "
+                     "ids is already formatted with special tokens for the model."
+                 )
+             return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+
+         if token_ids_1 is not None:
+             return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+         return [1] + ([0] * len(token_ids_0)) + [1]
+
+     def create_token_type_ids_from_sequences(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
+         A BERT sequence pair mask has the following format:
+         ::
+             0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+             | first sequence     | second sequence |
+         If token_ids_1 is None, only returns the first portion of the mask (0's).
+         Args:
+             token_ids_0 (:obj:`List[int]`):
+                 List of ids.
+             token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+                 Optional second list of IDs for sequence pairs.
+         Returns:
+             :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
+             sequence(s).
+         """
+         sep = [self.sep_token_id]
+         cls = [self.cls_token_id]
+         if token_ids_1 is None:
+             return len(cls + token_ids_0 + sep) * [0]
+         return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+     def save_vocabulary(self, vocab_path):
+         """
+         Save the vocabulary (copy of the original file) and special tokens file to a directory.
+         Args:
+             vocab_path (:obj:`str`):
+                 The directory in which to save the vocabulary.
+         Returns:
+             :obj:`Tuple(str)`: Paths to the files saved.
+         """
+         index = 0
+         if os.path.isdir(vocab_path):
+             vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
+         else:
+             vocab_file = vocab_path
+         with open(vocab_file, "w", encoding="utf-8") as writer:
+             for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                 if index != token_index:
+                     logger.warning(
+                         "Saving vocabulary to {}: vocabulary indices are not consecutive."
+                         " Please check that the vocabulary is not corrupted!".format(vocab_file)
+                     )
+                     index = token_index
+                 writer.write(token + "\n")
+                 index += 1
+         return (vocab_file,)
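A minimal usage sketch of the SPE tokenizer (file paths as wired up in the scoring classes above; the SMILES is illustrative):

    tok = SMILES_SPE_Tokenizer('new_vocab.txt', 'new_splits.txt')
    enc = tok('CC(=O)O', return_tensors='pt')
    # input_ids begin with [CLS] and end with [SEP], per build_inputs_with_special_tokens
    print(tok.decode(enc['input_ids'][0].tolist()))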
functions/tokenizer/new_splits.txt ADDED
@@ -0,0 +1,159 @@
+ c 1
+ c 2
+ c 3
+ c 4
+ c 5
+ c 6
+ c 7
+ c 8
+ c 9
+ ( c1
+ ( c2
+ c1 )
+ c2 )
+ n 1
+ n 2
+ n 3
+ n 4
+ n 5
+ n 6
+ n 7
+ n 8
+ n 9
+ ( n1
+ ( n2
+ n1 )
+ n2 )
+ O 1
+ O 2
+ O 3
+ O 4
+ O 5
+ O 6
+ O 7
+ O 8
+ O 9
+ ( O1
+ ( O2
+ O2 )
+ O2 )
+ = O
+ = C
+ = c
+ = N
+ = n
+ =C C
+ =C N
+ =C c
+ =c c
+ =N C
+ =N c
+ =n C
+ =n c
+ # N
+ # C
+ #N C
+ #C C
+ #C N
+ #N N
+ ( C
+ C )
+ ( O
+ O )
+ ( N
+ N )
+ Br c
+ ( =O
+ (=O )
+ C (=O)
+ C =O
+ C =N
+ C #N
+ C #C
+ C C
+ CC C
+ CC N
+ CC O
+ CC S
+ CC c
+ CC n
+ C N
+ CN C
+ CN c
+ C O
+ CO C
+ CO N
+ CO c
+ C S
+ CS C
+ CS S
+ CS c
+ C c
+ Cl c
+ C n
+ F c
+ N C
+ NC C
+ NC c
+ N N
+ N O
+ N c
+ N n
+ O C
+ OC C
+ OC O
+ OC c
+ O N
+ O O
+ O c
+ S C
+ SC C
+ SC c
+ S S
+ S c
+ c c
+ cc c
+ cc n
+ cc o
+ cc s
+ cc cc
+ c n
+ cn c
+ cn n
+ c o
+ co c
+ c s
+ cs c
+ cs n
+ n c
+ nc c
+ nc n
+ nc o
+ nc s
+ n n
+ nn c
+ nn n
+ n o
+ no c
+ no n
+ n s
+ ns c
+ ns n
+ o c
+ oc c
+ o n
+ s c
+ sc c
+ sc n
+ s n
+ N P
+ P N
+ C P
+ P C
+ N S
+ S N
+ C S
+ S C
+ S P
+ P S
+ C I
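Note: each line of new_splits.txt appears to be a SMILES Pair Encoding merge rule in BPE style — the two space-separated tokens on a line are merged into one vocabulary token, with rules applied in file order. An illustrative (hypothetical) trace using the "C C" and "CC C" rules above:

    # atom tokens of propane 'CCC': ['C', 'C', 'C'] -> ['CC', 'C'] -> ['CCC']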
functions/tokenizer/new_vocab.txt ADDED
@@ -0,0 +1,586 @@
+ [PAD]
+ [UNK]
+ [CLS]
+ [SEP]
+ [MASK]
+ #
+ %
+ (
+ )
+ +
+ -
+ /
+ 0
+ 1
+ 2
+ 3
+ 4
+ 5
+ 6
+ 7
+ 8
+ 9
+ =
+ @
+ A
+ B
+ Br
+ Brc
+ C
+ CC
+ CCC
+ CCN
+ CCO
+ CCS
+ CCc
+ CCn
+ CN
+ CNC
+ CNc
+ CO
+ COC
+ CON
+ COc
+ CS
+ CSC
+ CSS
+ CSc
+ Cc
+ Cl
+ Clc
+ Cn
+ F
+ Fc
+ H
+ I
+ K
+ L
+ M
+ N
+ NC
+ NCC
+ NCc
+ NN
+ NO
+ Nc
+ Nn
+ O
+ OC
+ OCC
+ OCO
+ OCc
+ ON
+ OO
+ Oc
+ P
+ R
+ S
+ SC
+ SCC
+ SCc
+ SS
+ Sc
+ T
+ X
+ Z
+ [
+ \\
+ (/
+ ]
+ a
+ b
+ c
+ cc
+ ccc
+ ccn
+ cco
+ ccs
+ cn
+ cnc
+ cnn
+ co
+ coc
+ cs
+ csc
+ csn
+ e
+ g
+ i
+ l
+ n
+ nc
+ ncc
+ ncn
+ nco
+ ncs
+ nn
+ nnc
+ nnn
+ no
+ noc
+ non
+ ns
+ nsc
+ nsn
+ o
+ oc
+ occ
+ on
+ p
+ r
+ s
+ sc
+ scc
+ scn
+ sn
+ t
+ c1
+ c2
+ c3
+ c4
+ c5
+ c6
+ c7
+ c8
+ c9
+ n1
+ n2
+ n3
+ n4
+ n5
+ n6
+ n7
+ n8
+ n9
+ O1
+ O2
+ O3
+ O4
+ O5
+ O6
+ O7
+ O8
+ O9
+ (c1
+ (c2
+ c1)
+ c2)
+ (n1
+ (n2
+ n1)
+ n2)
+ (O1
+ (O2
+ O2)
+ =O
+ =C
+ =c
+ =N
+ =n
+ =CC
+ =CN
+ =Cc
+ =cc
+ =NC
+ =Nc
+ =nC
+ =nc
+ #C
+ #CC
+ #CN
+ #N
+ #NC
+ #NN
+ (C
+ C)
+ (O
+ O)
+ (N
+ N)
+ NP
+ PN
+ CP
+ PC
+ NS
+ SN
+ SP
+ PS
+ C(=O)
+ (/Br)
+ (/C#N)
+ (/C)
+ (/C=N)
+ (/C=O)
+ (/CBr)
+ (/CC)
+ (/CCC)
+ (/CCF)
+ (/CCN)
+ (/CCO)
+ (/CCl)
+ (/CI)
+ (/CN)
+ (/CO)
+ (/CS)
+ (/Cl)
+ (/F)
+ (/I)
+ (/N)
+ (/NC)
+ (/NCC)
+ (/NO)
+ (/O)
+ (/OC)
+ (/OCC)
+ (/S)
+ (/SC)
+ (=C)
+ (=C/C)
+ (=C/F)
+ (=C/I)
+ (=C/N)
+ (=C/O)
+ (=CBr)
+ (=CC)
+ (=CCF)
+ (=CCN)
+ (=CCO)
+ (=CCl)
+ (=CF)
+ (=CI)
+ (=CN)
+ (=CO)
+ (=C\\C)
+ (=C\\F)
+ (=C\\I)
+ (=C\\N)
+ (=C\\O)
+ (=N)
+ (=N/C)
+ (=N/N)
+ (=N/O)
+ (=NBr)
+ (=NC)
+ (=NCC)
+ (=NCl)
+ (=NN)
+ (=NO)
+ (=NOC)
+ (=N\\C)
+ (=N\\N)
+ (=N\\O)
+ (=O)
+ (=S)
+ (B)
+ (Br)
+ (C#C)
+ (C#CC)
+ (C#CI)
+ (C#CO)
+ (C#N)
+ (C#SN)
+ (C)
+ (C=C)
+ (C=CF)
+ (C=CI)
+ (C=N)
+ (C=NN)
+ (C=NO)
+ (C=O)
+ (C=S)
+ (CBr)
+ (CC#C)
+ (CC#N)
+ (CC)
+ (CC=C)
+ (CC=O)
+ (CCBr)
+ (CCC)
+ (CCCC)
+ (CCCF)
+ (CCCI)
+ (CCCN)
+ (CCCO)
+ (CCCS)
+ (CCCl)
+ (CCF)
+ (CCI)
+ (CCN)
+ (CCNC)
+ (CCNN)
+ (CCNO)
+ (CCO)
+ (CCOC)
+ (CCON)
+ (CCS)
+ (CCSC)
+ (CCl)
+ (CF)
+ (CI)
+ (CN)
+ (CN=O)
+ (CNC)
+ (CNCC)
+ (CNCO)
+ (CNN)
+ (CNNC)
+ (CNO)
+ (CNOC)
+ (CO)
+ (COC)
+ (COCC)
+ (COCI)
+ (COCN)
+ (COCO)
+ (COF)
+ (CON)
+ (COO)
+ (CS)
+ (CSC)
+ (CSCC)
+ (CSCF)
+ (CSO)
+ (Cl)
+ (F)
+ (I)
+ (N)
+ (N=N)
+ (N=NO)
+ (N=O)
+ (N=S)
+ (NBr)
+ (NC#N)
+ (NC)
+ (NC=N)
+ (NC=O)
+ (NC=S)
+ (NCBr)
+ (NCC)
+ (NCCC)
+ (NCCF)
+ (NCCN)
+ (NCCO)
+ (NCCS)
+ (NCCl)
+ (NCNC)
+ (NCO)
+ (NCS)
+ (NCl)
+ (NN)
+ (NN=O)
+ (NNC)
+ (NO)
+ (NOC)
+ (O)
+ (OC#N)
+ (OC)
+ (OC=C)
+ (OC=O)
+ (OC=S)
+ (OCBr)
+ (OCC)
+ (OCCC)
+ (OCCF)
+ (OCCI)
+ (OCCN)
+ (OCCO)
+ (OCCS)
+ (OCCl)
+ (OCF)
+ (OCI)
+ (OCO)
+ (OCOC)
+ (OCON)
+ (OCSC)
+ (OCl)
+ (OI)
+ (ON)
+ (OO)
+ (OOC)
+ (OOCC)
+ (OOSN)
+ (OSC)
+ (P)
+ (S)
+ (SC#N)
+ (SC)
+ (SCC)
+ (SCCC)
+ (SCCF)
+ (SCCN)
+ (SCCO)
+ (SCCS)
+ (SCCl)
+ (SCF)
+ (SCN)
+ (SCOC)
+ (SCSC)
+ (SCl)
+ (SI)
+ (SN)
+ (SN=O)
+ (SO)
+ (SOC)
+ (SOOO)
+ (SS)
+ (SSC)
+ (SSCC)
+ ([At])
+ ([O-])
+ ([O])
+ ([S-])
+ (\\Br)
+ (\\C#N)
+ (\\C)
+ (\\C=N)
+ (\\C=O)
+ (\\CBr)
+ (\\CC)
+ (\\CCC)
+ (\\CCO)
+ (\\CCl)
+ (\\CF)
+ (\\CN)
+ (\\CNC)
+ (\\CO)
+ (\\COC)
+ (\\Cl)
+ (\\F)
+ (\\I)
+ (\\N)
+ (\\NC)
+ (\\NCC)
+ (\\NN)
+ (\\NO)
+ (\\NOC)
+ (\\O)
+ (\\OC)
+ (\\OCC)
+ (\\ON)
+ (\\S)
+ (\\SC)
+ (\\SCC)
+ [Ag+]
+ [Ag-4]
+ [Ag]
+ [Al-3]
+ [Al]
+ [As+]
+ [AsH3]
+ [AsH]
+ [As]
+ [At]
+ [B-]
+ [B@-]
+ [B@@-]
+ [BH-]
+ [BH2-]
+ [BH3-]
+ [B]
+ [Ba]
+ [Br+2]
+ [BrH]
+ [Br]
+ [C+]
+ [C-]
+ [C@@H]
+ [C@@]
+ [C@H]
+ [C@]
+ [CH-]
+ [CH2]
+ [CH3]
+ [CH]
+ [C]
+ [CaH2]
+ [Ca]
+ [Cl+2]
+ [Cl+3]
+ [Cl+]
+ [Cs]
+ [FH]
+ [F]
+ [H]
+ [He]
+ [I+2]
+ [I+3]
+ [I+]
+ [IH]
+ [I]
+ [K]
+ [Kr]
+ [Li+]
+ [LiH]
+ [MgH2]
+ [Mg]
+ [N+]
+ [N-]
+ [N@+]
+ [N@@+]
+ [N@@]
+ [N@]
+ [NH+]
+ [NH-]
+ [NH2+]
+ [NH3]
+ [NH]
+ [N]
+ [Na]
+ [O+]
+ [O-]
+ [OH+]
+ [OH2]
+ [OH]
+ [O]
+ [P+]
+ [P@+]
+ [P@@+]
+ [P@@]
+ [P@]
+ [PH2]
+ [PH]
+ [P]
+ [Ra]
+ [Rb]
+ [S+]
+ [S-]
+ [S@+]
+ [S@@+]
+ [S@@]
+ [S@]
+ [SH+]
+ [SH2]
+ [SH]
+ [S]
+ [Se+]
+ [Se-2]
+ [SeH2]
+ [SeH]
+ [Se]
+ [Si@]
+ [SiH2]
+ [SiH]
+ [Si]
+ [SrH2]
+ [TeH]
+ [Te]
+ [Xe]
+ [Zn+2]
+ [Zn-2]
+ [Zn]
+ [b-]
+ [c+]
+ [c-]
+ [cH-]
+ [cH]
+ [c]
+ [n+]
+ [n-]
+ [nH]
+ [n]
+ [o+]
+ [s+]
+ [se+]
+ [se]
+ [te+]
+ [te]
scoring_functions.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import io
3
+ import subprocess
4
+ import warnings
5
+ import numpy as np
6
+ import pandas as pd
7
+ from typing import List
8
+ from loguru import logger
9
+ from tqdm import tqdm
10
+ from rdkit import Chem, rdBase, DataStructs
11
+ from rdkit.Chem import AllChem
12
+ import torch
13
+ from functions.binding.binding import BindingAffinity
14
+ from functions.permeability.permeability import Permeability
15
+ from functions.solubility.solubility import Solubility
16
+ from functions.hemolysis.hemolysis import Hemolysis
17
+ from functions.nonfouling.nonfouling import Nonfouling
18
+
19
+ class ScoringFunctions:
20
+ def __init__(self, score_func_names=None, prot_seqs=[]):
21
+ """
22
+ Class for generating score vectors given generated sequence
23
+
24
+ Args:
25
+ score_func_names: list of scoring function names to be evaluated
26
+ score_weights: weights to scale scores (default: 1)
27
+ target_protein: sequence of target protein binder
28
+ """
29
+ if score_func_names is None:
30
+ # just do unmasking based on validity of peptide bonds
31
+ self.score_func_names = []
32
+ else:
33
+ self.score_func_names = score_func_names
34
+
35
+ # self.weights = np.array([1] * len(self.score_func_names) if score_weights is None else score_weights)
36
+
37
+ # binding affinities
38
+ self.target_protein = prot_seqs
39
+ print(len(prot_seqs))
40
+
41
+ if ('binding_affinity1' in score_func_names) and (len(prot_seqs) == 1):
42
+ binding_affinity1 = BindingAffinity(prot_seqs[0])
43
+ binding_affinity2 = None
44
+ elif ('binding_affinity1' in score_func_names) and ('binding_affinity2' in score_func_names) and (len(prot_seqs) == 2):
45
+ binding_affinity1 = BindingAffinity(prot_seqs[0])
46
+ binding_affinity2 = BindingAffinity(prot_seqs[1])
47
+ else:
48
+ print("here")
49
+ binding_affinity1 = None
50
+ binding_affinity2 = None
51
+
52
+ permeability = Permeability()
53
+ sol = Solubility()
54
+ nonfouling = Nonfouling()
55
+ hemo = Hemolysis()
56
+
57
+ self.all_funcs = {'binding_affinity1': binding_affinity1,
58
+ 'binding_affinity2': binding_affinity2,
59
+ 'permeability': permeability,
60
+ 'nonfouling': nonfouling,
61
+ 'solubility': sol,
62
+ 'hemolysis': hemo
63
+ }
64
+
65
+ def forward(self, input_seqs):
66
+ scores = []
67
+
68
+ for i, score_func in enumerate(self.score_func_names):
69
+ score = self.all_funcs[score_func](input_seqs = input_seqs)
70
+
71
+ scores.append(score)
72
+
73
+ # convert to numpy arrays with shape (num_sequences, num_functions)
74
+ scores = np.float32(scores).T
75
+
76
+ return scores
77
+
78
+ def __call__(self, input_seqs: list):
79
+ return self.forward(input_seqs)
80
+
81
+
82
+ def unittest():
+     amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
+     tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
+     gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
+     glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
+     glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
+     ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
+     cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
+ 
+     num_iter = 0
+     score_func_times = [0, 1, 2, 3, 4, 5]
+ 
+     scoring = ScoringFunctions(score_func_names=['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability'], prot_seqs=[tfr])
+ 
+     smiles = ['N2[C@H](CC(C)C)C(=O)N1[C@@H](CCC1)C(=O)N1[C@@H](CCC1)C(=O)N1[C@@H](CCC1)C(=O)N[C@@H](Cc1ccccc1C(F)(F)F)C(=O)N1[C@@H](CCC1)C(=O)N[C@@H](CCSC)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](CC(=O)N)C2(=O)']
+ 
+     scores = scoring(input_seqs=smiles)
+     print(scores)
+     print(len(scores))
+ 
+ if __name__ == '__main__':
+     unittest()
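
Note: ScoringFunctions.forward returns scores with shape (num_sequences, num_functions), with columns ordered like score_func_names, so each property can be read off by column index. A minimal consumption sketch (the function names mirror the unittest above; the SMILES string is a placeholder):

    scoring = ScoringFunctions(score_func_names=['solubility', 'hemolysis'], prot_seqs=[])
    scores = scoring(['CC(C)C[C@@H](N)C(=O)O'])          # placeholder SMILES
    sol_score, hemo_score = scores[0, 0], scores[0, 1]  # row 0 = first input sequence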
train/binary_xg.py ADDED
@@ -0,0 +1,223 @@
+ import pandas as pd
+ import numpy as np
+ import torch
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import precision_recall_curve, f1_score
+ import optuna
+ from optuna.trial import TrialState
+ import xgboost as xgb
+ import os
+ from datasets import load_from_disk
+ from lightning.pytorch import seed_everything
+ from rdkit import Chem, rdBase, DataStructs
+ from typing import List
+ from rdkit.Chem import AllChem
+ import matplotlib.pyplot as plt
+ from sklearn.metrics import accuracy_score, roc_auc_score
+ 
+ base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
+ 
+ def save_and_plot_binary_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, threshold, output_path):
+     """
+     Saves the true and predicted values for the training and validation sets, and generates binary classification plots.
+ 
+     Parameters:
+         y_true_train (array): True labels for the training set.
+         y_pred_train (array): Predicted probabilities for the training set.
+         y_true_val (array): True labels for the validation set.
+         y_pred_val (array): Predicted probabilities for the validation set.
+         threshold (float): Classification threshold for predictions.
+         output_path (str): Directory to save the CSV files and plots.
+     """
+     os.makedirs(output_path, exist_ok=True)
+ 
+     # Convert probabilities to binary predictions
+     y_pred_train_binary = (y_pred_train >= threshold).astype(int)
+     y_pred_val_binary = (y_pred_val >= threshold).astype(int)
+ 
+     # Save training predictions
+     train_df = pd.DataFrame({
+         'True Label': y_true_train,
+         'Predicted Probability': y_pred_train,
+         'Predicted Label': y_pred_train_binary
+     })
+     train_df.to_csv(os.path.join(output_path, 'train_predictions_binary.csv'), index=False)
+ 
+     # Save validation predictions
+     val_df = pd.DataFrame({
+         'True Label': y_true_val,
+         'Predicted Probability': y_pred_val,
+         'Predicted Label': y_pred_val_binary
+     })
+     val_df.to_csv(os.path.join(output_path, 'val_predictions_binary.csv'), index=False)
+ 
+     # Plot training predictions
+     plot_binary_correlation(
+         y_true_train,
+         y_pred_train,
+         threshold,
+         title="Training Set Binary Classification Plot",
+         output_file=os.path.join(output_path, 'train_classification_plot.png')
+     )
+ 
+     # Plot validation predictions
+     plot_binary_correlation(
+         y_true_val,
+         y_pred_val,
+         threshold,
+         title="Validation Set Binary Classification Plot",
+         output_file=os.path.join(output_path, 'val_classification_plot.png')
+     )
+ 
+ def plot_binary_correlation(y_true, y_pred, threshold, title, output_file):
+     """
+     Generates a scatter plot for binary classification and saves it to a file.
+ 
+     Parameters:
+         y_true (array): True labels.
+         y_pred (array): Predicted probabilities.
+         threshold (float): Classification threshold for predictions.
+         title (str): Title of the plot.
+         output_file (str): Path to save the plot.
+     """
+     # Scatter plot
+     plt.figure(figsize=(10, 8))
+     plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')
+ 
+     # Add threshold line
+     plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')
+ 
+     # Add annotations
+     plt.title(title)
+     plt.xlabel("True Labels")
+     plt.ylabel("Predicted Probability")
+     plt.legend()
+ 
+     # Save and show the plot
+     plt.tight_layout()
+     plt.savefig(output_file)
+     plt.show()
+ 
+ seed_everything(42)
+ 
+ dataset = load_from_disk(f'{base_path}/data/solubility')
+ 
+ sequences = np.stack(dataset['sequence'])  # sequences are SMILES strings
+ labels = np.stack(dataset['labels'])
+ embeddings = np.stack(dataset['embedding'])
+ 
+ # Initialize best F1 score and model path
+ best_f1 = -np.inf
+ best_model_path = f"{base_path}/src/solubility"
+ 
+ # Trial callback
+ def trial_info_callback(study, trial):
+     if study.best_trial == trial:
+         print(f"Trial {trial.number}:")
+         print(f"  Weighted F1 Score: {trial.value}")
+ 
+ 
+ def objective(trial):
+     # Define hyperparameters
+     params = {
+         'objective': 'binary:logistic',
+         'lambda': trial.suggest_float('lambda', 1e-8, 50.0, log=True),
+         'alpha': trial.suggest_float('alpha', 1e-8, 50.0, log=True),
+         'colsample_bytree': trial.suggest_float('colsample_bytree', 0.3, 1.0),
+         'subsample': trial.suggest_float('subsample', 0.5, 1.0),
+         'learning_rate': trial.suggest_float('learning_rate', 0.001, 0.3),
+         'max_depth': trial.suggest_int('max_depth', 2, 15),
+         'min_child_weight': trial.suggest_int('min_child_weight', 1, 500),
+         'gamma': trial.suggest_float('gamma', 0, 10.0),
+         'tree_method': 'hist',
+         'device': 'cuda:6',
+     }
+ 
+     # Suggest number of boosting rounds
+     num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)
+     threshold = 0.5  # classification threshold
+ 
+     # Split the data
+     train_idx, val_idx = train_test_split(
+         np.arange(len(sequences)), test_size=0.2, stratify=labels, random_state=42
+     )
+     train_subset = dataset.select(train_idx).with_format("torch")
+     val_subset = dataset.select(val_idx).with_format("torch")
+ 
+     # Extract embeddings and labels for train/validation
+     train_embeddings = np.array(train_subset['embedding'])
+     valid_embeddings = np.array(val_subset['embedding'])
+     train_labels = np.array(train_subset['labels'])
+     valid_labels = np.array(val_subset['labels'])
+ 
+     # Prepare training and validation sets
+     dtrain = xgb.DMatrix(train_embeddings, label=train_labels)
+     dvalid = xgb.DMatrix(valid_embeddings, label=valid_labels)
+ 
+     # Train the model
+     model = xgb.train(
+         params=params,
+         dtrain=dtrain,
+         num_boost_round=num_boost_round,
+         evals=[(dvalid, "validation")],
+         early_stopping_rounds=50,
+         verbose_eval=False,
+     )
+ 
+     # Predict probabilities
+     preds_train = model.predict(dtrain)
+     preds_val = model.predict(dvalid)
+ 
+     # Calculate metrics
+     f1_val = f1_score(valid_labels, (preds_val >= threshold).astype(int), average="weighted")
+     auc_val = roc_auc_score(valid_labels, preds_val)
+     print(f"Trial {trial.number}: AUC: {auc_val:.3f}, F1 Score: {f1_val:.3f}")
+ 
+     # Save the model if it has the best F1 score so far
+     current_best = trial.study.user_attrs.get("best_f1", -np.inf)
+     if f1_val > current_best:
+         trial.study.set_user_attr("best_f1", f1_val)
+         trial.study.set_user_attr("best_auc", auc_val)
+         trial.study.set_user_attr("best_trial", trial.number)
+         os.makedirs(best_model_path, exist_ok=True)
+ 
+         # Save the model
+         model.save_model(os.path.join(best_model_path, "best_model_f1.json"))
+         print(f"✓ NEW BEST! Trial {trial.number}: F1={f1_val:.4f}, AUC={auc_val:.4f} - Model saved!")
+ 
+         # Save and plot binary predictions
+         save_and_plot_binary_predictions(
+             train_labels, preds_train, valid_labels, preds_val, threshold, best_model_path
+         )
+ 
+     return f1_val
+ 
+ if __name__ == "__main__":
+     study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
+     study.optimize(objective, n_trials=200, callbacks=[trial_info_callback])
+ 
+     # Prepare summary text
+     summary = []
+     summary.append("\n" + "="*60)
+     summary.append("OPTIMIZATION COMPLETE")
+     summary.append("="*60)
+     summary.append(f"Number of finished trials: {len(study.trials)}")
+     summary.append(f"\nBest Trial: #{study.user_attrs.get('best_trial', 'N/A')}")
+     summary.append(f"Best F1 Score: {study.user_attrs.get('best_f1', None):.4f}")
+     summary.append(f"Best AUC Score: {study.user_attrs.get('best_auc', None):.4f}")
+     summary.append(f"Optuna Best Trial Value: {study.best_trial.value:.4f}")
+     summary.append(f"\nBest hyperparameters:")
+     for key, value in study.best_trial.params.items():
+         summary.append(f"  {key}: {value}")
+     summary.append("="*60)
+ 
+     # Print to console
+     for line in summary:
+         print(line)
+ 
+     # Save to file
+     metrics_file = os.path.join(best_model_path, "optimization_metrics.txt")
+     with open(metrics_file, 'w') as f:
+         f.write('\n'.join(summary))
+     print(f"\n✓ Metrics saved to: {metrics_file}")
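
Once the study finishes, the best solubility booster is serialized to best_model_f1.json under best_model_path. A minimal inference sketch (the embedding array is a stand-in; real inputs must come from the same encoder, with the same width, that produced dataset['embedding']):

    import numpy as np
    import xgboost as xgb

    booster = xgb.Booster()
    booster.load_model("best_model_f1.json")    # saved under best_model_path

    emb = np.zeros((4, 768), dtype=np.float32)  # placeholder; match the training embedding width
    probs = booster.predict(xgb.DMatrix(emb))   # binary:logistic outputs probabilities
    preds = (probs >= 0.5).astype(int)          # same 0.5 threshold as in training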
train/permeability_xg.py ADDED
@@ -0,0 +1,186 @@
+ import pandas as pd
+ import numpy as np
+ import optuna
+ from optuna.trial import TrialState
+ from rdkit import Chem
+ from rdkit.Chem import AllChem
+ from sklearn.metrics import mean_squared_error
+ from sklearn.model_selection import train_test_split
+ import xgboost as xgb
+ import os
+ from datasets import load_from_disk
+ from scipy.stats import spearmanr
+ import matplotlib.pyplot as plt
+ 
+ base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
+ 
+ def save_and_plot_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, output_path):
+     os.makedirs(output_path, exist_ok=True)
+ 
+     # Save training predictions
+     train_df = pd.DataFrame({'True Permeability': y_true_train, 'Predicted Permeability': y_pred_train})
+     train_df.to_csv(os.path.join(output_path, 'train_predictions.csv'), index=False)
+ 
+     # Save validation predictions
+     val_df = pd.DataFrame({'True Permeability': y_true_val, 'Predicted Permeability': y_pred_val})
+     val_df.to_csv(os.path.join(output_path, 'val_predictions.csv'), index=False)
+ 
+     # Plot training predictions
+     plot_correlation(
+         y_true_train,
+         y_pred_train,
+         title="Training Set Correlation Plot",
+         output_file=os.path.join(output_path, 'train_correlation.png'),
+     )
+ 
+     # Plot validation predictions
+     plot_correlation(
+         y_true_val,
+         y_pred_val,
+         title="Validation Set Correlation Plot",
+         output_file=os.path.join(output_path, 'val_correlation.png'),
+     )
+ 
+ def plot_correlation(y_true, y_pred, title, output_file):
+     spearman_corr, _ = spearmanr(y_true, y_pred)
+ 
+     # Scatter plot
+     plt.figure(figsize=(10, 8))
+     plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')
+     plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], color='teal', linestyle='--', label='Ideal fit')
+ 
+     # Add annotations
+     plt.title(f"{title}\nSpearman Correlation: {spearman_corr:.3f}")
+     plt.xlabel("True Permeability (logP)")
+     plt.ylabel("Predicted Permeability (logP)")
+     plt.legend()
+ 
+     # Save and show the plot
+     plt.tight_layout()
+     plt.savefig(output_file)
+     plt.show()
+ 
+ # Load dataset
+ dataset = load_from_disk(f'{base_path}/data/permeability')
+ 
+ # Extract sequences, labels, and embeddings
+ sequences = np.stack(dataset['sequence'])
+ labels = np.stack(dataset['labels'])  # Regression labels
+ embeddings = np.stack(dataset['embedding'])  # Pre-trained embeddings
+ 
+ # Function to compute Morgan fingerprints
+ def compute_morgan_fingerprints(smiles_list, radius=2, n_bits=2048):
+     fps = []
+     for smiles in smiles_list:
+         mol = Chem.MolFromSmiles(smiles)
+         if mol is not None:
+             fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
+             fps.append(np.array(fp))
+         else:
+             # If the SMILES string is invalid, use a zero vector
+             fps.append(np.zeros(n_bits))
+             print(f"Invalid SMILES: {smiles}")
+     return np.array(fps)
+ 
+ # Compute Morgan fingerprints for the sequences
+ #morgan_fingerprints = compute_morgan_fingerprints(sequences)
+ 
+ # Concatenate embeddings with Morgan fingerprints
+ #input_features = np.concatenate([embeddings, morgan_fingerprints], axis=1)
+ input_features = embeddings
+ 
+ # Initialize global variables
+ best_model_path = f"{base_path}/src/permeability"
+ os.makedirs(best_model_path, exist_ok=True)
+ 
+ def trial_info_callback(study, trial):
+     if study.best_trial == trial:
+         print(f"Trial {trial.number}:")
+         print(f"  MSE: {trial.value}")
+ 
+ def objective(trial):
+     # Define hyperparameters (scale_pos_weight is omitted: it only applies to binary objectives)
+     params = {
+         'objective': 'reg:squarederror',
+         'lambda': trial.suggest_float('lambda', 0.1, 10.0, log=True),
+         'alpha': trial.suggest_float('alpha', 0.1, 10.0, log=True),
+         'gamma': trial.suggest_float('gamma', 0, 5),
+         'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
+         'subsample': trial.suggest_float('subsample', 0.6, 0.9),
+         'learning_rate': trial.suggest_float('learning_rate', 1e-5, 0.1),
+         'max_depth': trial.suggest_int('max_depth', 2, 30),
+         'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
+         'tree_method': 'hist',
+         'device': 'cuda:6',
+     }
+     num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)
+ 
+     # Train-validation split
+     X_train, X_val, y_train, y_val = train_test_split(input_features, labels, test_size=0.2, random_state=42)
+ 
+     # Convert data to DMatrix
+     dtrain = xgb.DMatrix(X_train, label=y_train)
+     dvalid = xgb.DMatrix(X_val, label=y_val)
+ 
+     # Train XGBoost
+     model = xgb.train(
+         params=params,
+         dtrain=dtrain,
+         num_boost_round=num_boost_round,
+         evals=[(dvalid, "validation")],
+         early_stopping_rounds=50,
+         verbose_eval=False,
+     )
+ 
+     # Predict and evaluate
+     preds_train = model.predict(dtrain)
+     preds_val = model.predict(dvalid)
+ 
+     mse = mean_squared_error(y_val, preds_val)
+ 
+     # Calculate Spearman rank correlation for both train and validation
+     spearman_train, _ = spearmanr(y_train, preds_train)
+     spearman_val, _ = spearmanr(y_val, preds_val)
+     print(f"Train Spearman: {spearman_train:.4f}, Val Spearman: {spearman_val:.4f}")
+ 
+     # Save the best model
+     if trial.study.user_attrs.get("best_mse", np.inf) > mse:
+         trial.study.set_user_attr("best_mse", mse)
+         trial.study.set_user_attr("best_spearman_train", spearman_train)
+         trial.study.set_user_attr("best_spearman_val", spearman_val)
+         trial.study.set_user_attr("best_trial", trial.number)
+         model.save_model(os.path.join(best_model_path, "best_model.json"))
+         save_and_plot_predictions(y_train, preds_train, y_val, preds_val, best_model_path)
+         print(f"✓ NEW BEST! Trial {trial.number}: MSE={mse:.4f}, Train Spearman={spearman_train:.4f}, Val Spearman={spearman_val:.4f}")
+ 
+     return mse
+ 
+ if __name__ == "__main__":
+     study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
+     study.optimize(objective, n_trials=200, callbacks=[trial_info_callback])
+ 
+     # Prepare summary text
+     summary = []
+     summary.append("\n" + "="*60)
+     summary.append("OPTIMIZATION COMPLETE")
+     summary.append("="*60)
+     summary.append(f"Number of finished trials: {len(study.trials)}")
+     summary.append(f"\nBest Trial: #{study.user_attrs.get('best_trial', 'N/A')}")
+     summary.append(f"Best MSE: {study.best_trial.value:.4f}")
+     summary.append(f"Best Training Spearman Correlation: {study.user_attrs.get('best_spearman_train', None):.4f}")
+     summary.append(f"Best Validation Spearman Correlation: {study.user_attrs.get('best_spearman_val', None):.4f}")
+     summary.append(f"\nBest hyperparameters:")
+     for key, value in study.best_trial.params.items():
+         summary.append(f"  {key}: {value}")
+     summary.append("="*60)
+ 
+     # Print to console
+     for line in summary:
+         print(line)
+ 
+     # Save to file
+     metrics_file = os.path.join(best_model_path, "optimization_metrics.txt")
+     with open(metrics_file, 'w') as f:
+         f.write('\n'.join(summary))
+     print(f"\n✓ Metrics saved to: {metrics_file}")
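
A closing note on the commented-out fingerprint path in train/permeability_xg.py: if re-enabled, the Morgan bits would be concatenated onto the embeddings before the train/validation split, and any model trained that way must also see concatenated features at inference time. A sketch under that assumption, reusing compute_morgan_fingerprints, sequences, and embeddings as defined in the script:

    # Widen the feature matrix from embed_dim to embed_dim + 2048
    morgan_fingerprints = compute_morgan_fingerprints(sequences)  # (N, 2048) bit vectors
    input_features = np.concatenate([embeddings, morgan_fingerprints], axis=1)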