Commit
·
813c6b1
1
Parent(s):
3e730f5
add functions
Browse files- README.md +2 -2
- functions/binding.py +186 -0
- functions/hemolysis.py +69 -0
- functions/nonfouling.py +69 -0
- functions/permeability.py +167 -0
- functions/solubility.py +68 -0
- functions/tokenizer/__pycache__/my_tokenizers.cpython-310.pyc +0 -0
- functions/tokenizer/my_tokenizers.py +398 -0
- functions/tokenizer/new_splits.txt +159 -0
- functions/tokenizer/new_vocab.txt +586 -0
- scoring_functions.py +103 -0
- train/binary_xg.py +223 -0
- train/permeability_xg.py +186 -0
README.md
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8b4a57e9caf84b0991a9a349cb28b44049995f4a51ccc3118a0114baf856f36a
|
| 3 |
+
size 839
|
functions/binding.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import esm
|
| 6 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 7 |
+
from transformers import AutoModelForMaskedLM, AutoModelForCausalLM, AutoTokenizer, AutoModel
|
| 8 |
+
|
| 9 |
+
base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
|
| 10 |
+
|
| 11 |
+
class ImprovedBindingPredictor(nn.Module):
    """Cross-attention model that predicts protein-peptide binding affinity.

    Takes a pooled ESM-2 protein embedding and a pooled PeptideCLM SMILES
    embedding, runs both through a stack of shared cross-attention blocks,
    and emits (regression affinity, 3-way binding-class logits).
    """

    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=768,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1):
        """Build the projection layers, attention stack, and prediction heads.

        Args:
            esm_dim: width of the incoming ESM-2 protein embedding.
            smiles_dim: width of the incoming PeptideCLM SMILES embedding.
            hidden_dim: shared internal width after projection.
            n_heads: attention heads per cross-attention block.
            n_layers: number of cross-attention blocks.
            dropout: dropout rate used in attention and FFN layers.
        """
        super().__init__()

        # Define binding thresholds (pKd-style scale; used by get_binding_class)
        self.tight_threshold = 7.5  # Kd/Ki/IC50 <= ~30nM
        self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1uM

        # Project both modalities to the same dimension
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        # Cross attention blocks with layer norm; each block's weights are
        # shared between the protein->SMILES and SMILES->protein passes.
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        # Prediction heads: a shared trunk over the concatenated pools
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression head (scalar affinity)
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Classification head (3 classes: tight, medium, loose binding)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values to class indices
        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            # default 0 (tight); overwrite medium/weak entries
            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0  # tight binding
            elif affinity < self.weak_threshold:
                return 2  # weak binding
            else:
                return 1  # medium binding

    def forward(self, protein_emb, smiles_emb):
        """Score one protein/SMILES embedding pair.

        NOTE(review): callers pass mean-pooled (1, dim) embeddings;
        MultiheadAttention without batch_first then treats them as unbatched
        (L, E) sequences of length 1 — confirm this matches training.

        Returns:
            (regression_output, classification_logits)
        """
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(smiles_emb))

        #protein = protein.transpose(0, 1)
        #smiles = smiles.transpose(0, 1)

        # Cross attention layers: alternate directions with residual + norm
        for layer in self.cross_attention_layers:
            # Protein attending to SMILES
            attended_protein = layer['attention'](
                protein, smiles, smiles
            )[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            # SMILES attending to protein
            attended_smiles = layer['attention'](
                smiles, protein, protein
            )[0]
            smiles = layer['norm1'](smiles + attended_smiles)
            smiles = layer['norm2'](smiles + layer['ffn'](smiles))

        # Get sequence-level representations (mean over the sequence axis)
        protein_pool = torch.mean(protein, dim=0)
        smiles_pool = torch.mean(smiles, dim=0)

        # Concatenate both representations
        combined = torch.cat([protein_pool, smiles_pool], dim=-1)

        # Shared features
        shared_features = self.shared_head(combined)

        regression_output = self.regression_head(shared_features)
        classification_logits = self.classification_head(shared_features)

        return regression_output, classification_logits
|
| 119 |
+
|
| 120 |
+
class BindingAffinity:
    """Scores peptide SMILES strings for predicted binding affinity against a
    fixed protein target.

    The target protein is embedded once with ESM-2 at construction time;
    each call then embeds the peptide SMILES with PeptideCLM and feeds both
    embeddings to a pretrained ImprovedBindingPredictor.
    """

    def __init__(self, prot_seq, model_type='PeptideCLM'):
        """Load models and precompute the pooled target-protein embedding.

        Args:
            prot_seq: amino-acid sequence of the target protein.
            model_type: kept for interface compatibility; currently unused.
        """
        # NOTE: removed a spurious super().__init__() — this class has no base.

        # peptide embeddings
        self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.pep_tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/functions/tokenizer/new_vocab.txt',
                                                  f'{base_path}/functions/tokenizer/new_splits.txt')
        self.model = ImprovedBindingPredictor()
        checkpoint = torch.load(f'{base_path}/src/binding/best_model.pt', weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # load ESM-2 model
        self.prot_tokenizer = alphabet.get_batch_converter()  # load esm tokenizer

        data = [("target", prot_seq)]
        # get tokenized protein
        _, _, prot_tokens = self.prot_tokenizer(data)
        with torch.no_grad():
            # call the module (not .forward) so hooks run
            results = self.esm_model(prot_tokens, repr_layers=[33])
        prot_emb = results["representations"][33]

        # Mean-pool residue embeddings for the single target -> (1, esm_dim)
        self.prot_emb = torch.mean(prot_emb[0], dim=0, keepdim=True)

    def forward(self, input_seqs):
        """Return a list of predicted affinity scores, one per input SMILES."""
        scores = []
        with torch.no_grad():
            for seq in input_seqs:
                pep_tokens = self.pep_tokenizer(seq, return_tensors='pt', padding=True)

                emb = self.pep_model(input_ids=pep_tokens['input_ids'],
                                     attention_mask=pep_tokens['attention_mask'],
                                     output_hidden_states=True)

                # Mean-pool token embeddings -> (1, smiles_dim)
                pep_emb = emb.last_hidden_state.squeeze(0)
                pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)

                # regression output is the score; classification logits unused here
                score, logits = self.model(self.prot_emb, pep_emb)
                scores.append(score.item())
        return scores

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
|
| 169 |
+
|
| 170 |
+
def unittest():
    """Smoke test: score one peptide SMILES against the transferrin receptor.

    Several alternative target sequences are kept here for convenience;
    only `tfr` is used below.
    """
    amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
    tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
    gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
    glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
    glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
    ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'

    binding = BindingAffinity(tfr)

    # His-Ala-Ile-Tyr-Pro-Arg-His peptide as SMILES
    seq = ["CC[C@H](C)[C@H](NC(=O)[C@H](C)NC(=O)[C@@H](N)Cc1c[nH]cn1)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N1CCC[C@H]1C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1c[nH]cn1)C(=O)O"]

    scores = binding(seq)
    print(scores)

if __name__ == '__main__':
    unittest()
|
functions/hemolysis.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import xgboost as xgb
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import AutoModelForMaskedLM
|
| 7 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 8 |
+
import warnings
|
| 9 |
+
import numpy as np
|
| 10 |
+
from rdkit.Chem import Descriptors, rdMolDescriptors
|
| 11 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 12 |
+
from rdkit.Chem import AllChem
|
| 13 |
+
from typing import List
|
| 14 |
+
|
| 15 |
+
rdBase.DisableLog('rdApp.error')
|
| 16 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 17 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 18 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 19 |
+
|
| 20 |
+
base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
|
| 21 |
+
|
| 22 |
+
class Hemolysis:
    """Scores peptide SMILES strings for hemolysis.

    Embeds each sequence with PeptideCLM, classifies with a pretrained
    XGBoost model, and returns the probability of being NON-hemolytic
    (higher is safer).
    """

    def __init__(self):
        # Pretrained XGBoost classifier over PeptideCLM embeddings.
        # NOTE(review): path lacks a per-task subdirectory unlike the sibling
        # scorers — confirm it points at the hemolysis model.
        self.predictor = xgb.Booster(model_file=f'{base_path}/src/best_model_f1.json')
        self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/functions/tokenizer/new_vocab.txt',
                                              f'{base_path}/functions/tokenizer/new_splits.txt')

    def generate_embeddings(self, sequences):
        """Return a (n, dim) array of mean-pooled PeptideCLM embeddings."""
        pooled = []
        for seq in sequences:
            encoded = self.tokenizer(seq, return_tensors='pt')
            with torch.no_grad():
                hidden = self.emb_model(**encoded)
            # Mean pooling across sequence length
            vec = hidden.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            pooled.append(vec)
        return np.array(pooled)

    def get_scores(self, input_seqs: list):
        """Score each sequence; defaults to 1.0 when no features are produced."""
        default_scores = np.ones(len(input_seqs))
        matrix = self.generate_embeddings(input_seqs)

        if len(matrix) == 0:
            return default_scores

        # Sanitize embeddings before handing them to XGBoost
        matrix = np.nan_to_num(matrix, nan=0.)
        matrix = np.clip(matrix, np.finfo(np.float32).min, np.finfo(np.float32).max)

        probs = self.predictor.predict(xgb.DMatrix(matrix))
        # return the probability of it being not hemolytic
        return default_scores - probs

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
|
| 59 |
+
|
| 60 |
+
def unittest():
    """Smoke test: score one peptide SMILES and print the result."""
    hemo = Hemolysis()
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]

    scores = hemo(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
|
functions/nonfouling.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import xgboost as xgb
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import AutoModelForMaskedLM
|
| 7 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 8 |
+
import warnings
|
| 9 |
+
import numpy as np
|
| 10 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 11 |
+
from transformers import AutoModelForMaskedLM
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
rdBase.DisableLog('rdApp.error')
|
| 15 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 16 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 17 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 18 |
+
|
| 19 |
+
base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
|
| 20 |
+
|
| 21 |
+
class Nonfouling:
    """Scores peptide SMILES strings for nonfouling behavior.

    Embeds each sequence with PeptideCLM and classifies with a pretrained
    XGBoost model; returns the classifier's positive-class probability.
    """

    def __init__(self):
        # Pretrained XGBoost classifier over PeptideCLM embeddings
        self.predictor = xgb.Booster(model_file=f'{base_path}/src/nonfouling/best_model_f1.json')
        self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/functions/tokenizer/new_vocab.txt',
                                              f'{base_path}/functions/tokenizer/new_splits.txt')

    def generate_embeddings(self, sequences):
        """Return a (n, dim) array of mean-pooled PeptideCLM embeddings."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Score each sequence; defaults to 0.0 when no features are produced."""
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize embeddings before handing them to XGBoost
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        # return the predicted probability of being nonfouling
        # (comment previously said "not hemolytic" — copy-paste from hemolysis.py)
        return scores

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores
|
| 59 |
+
|
| 60 |
+
def unittest():
    """Smoke test: score one peptide SMILES and print the result."""
    nf = Nonfouling()
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]

    scores = nf(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
|
functions/permeability.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import xgboost as xgb
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import AutoModelForMaskedLM
|
| 7 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 8 |
+
import warnings
|
| 9 |
+
import numpy as np
|
| 10 |
+
from rdkit.Chem import Descriptors, rdMolDescriptors
|
| 11 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 12 |
+
from rdkit.Chem import AllChem
|
| 13 |
+
from typing import List
|
| 14 |
+
|
| 15 |
+
base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
|
| 16 |
+
|
| 17 |
+
rdBase.DisableLog('rdApp.error')
|
| 18 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 19 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 20 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 21 |
+
|
| 22 |
+
def fingerprints_from_smiles(smiles: List, size=2048):
    """Create ECFP fingerprints of smiles, with validity check.

    Args:
        smiles: list of SMILES strings.
        size: fingerprint bit length.

    Returns:
        (fps, valid_mask): an (n, size) numpy array (invalid SMILES rows are
        all-zero) and a parallel list of 1/0 validity flags.
    """
    fps = []
    valid_mask = []
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        fps.append(fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size)))

    if not fps:
        # np.concatenate raises on an empty list; return an empty feature block
        return np.zeros((0, size)), valid_mask
    return np.concatenate(fps, axis=0), valid_mask
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """Create an ECFP fingerprint of a molecule as a (1, size) numpy row.

    Uses the hashed Morgan fingerprint when `hashed` is True, otherwise the
    bit-vector variant.
    """
    fingerprint_fn = (AllChem.GetHashedMorganFingerprint if hashed
                      else AllChem.GetMorganFingerprintAsBitVect)
    bits = fingerprint_fn(molecule, radius, nBits=size)

    # Convert the RDKit bit structure into a flat numpy array, then reshape
    # to a single-row matrix so callers can stack rows.
    arr = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(bits, arr)
    return arr.reshape(1, -1)
|
| 45 |
+
|
| 46 |
+
def getMolDescriptors(mol, missingVal=0):
    """Calculate the full list of descriptors for a molecule.

    Evaluates every RDKit descriptor in Descriptors._descList plus a few
    Lipinski-style counts, substituting `missingVal` for any descriptor
    that fails.

    Returns:
        (values, names): parallel lists of descriptor values and names,
        RDKit descriptors first, custom descriptors last (order preserved).
    """
    values, names = [], []

    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    # Single loop over both descriptor sources (was two copy-pasted loops).
    for nm, fn in list(Descriptors._descList) + list(custom_descriptors.items()):
        try:
            val = fn(mol)
        except Exception:
            # Individual descriptors can fail on unusual molecules (or on
            # mol=None); substitute the sentinel rather than aborting the
            # whole vector. (Was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names
|
| 70 |
+
|
| 71 |
+
def get_pep_dps_from_smi(smi):
    """Compute the RDKit descriptor vector for a single SMILES string.

    NOTE(review): Chem.MolFromSmiles returns None (rather than raising) on
    an invalid SMILES, so this except branch rarely fires; in the None case
    getMolDescriptors(None) falls back to missingVal for every descriptor.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
    except:
        print(f"convert smi {smi} to molecule failed!")
        mol = None

    # Descriptor names are discarded; only the values are returned.
    dps, _ = getMolDescriptors(mol)
    return np.array(dps)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_pep_dps(smi_list):
    """Descriptor matrix for a list of SMILES; shape (n, 213).

    Returns an empty (0, 213) array for an empty input list.
    """
    if not smi_list:
        return np.zeros((0, 213))
    return np.array([get_pep_dps_from_smi(entry) for entry in smi_list])
|
| 86 |
+
|
| 87 |
+
def check_smi_validity(smiles: list):
    """Split SMILES into the parseable ones and their original indices.

    Returns:
        (valid_smi, valid_idx): the SMILES RDKit could parse, plus each
        one's index in the input list.
    """
    valid_smi, valid_idx = [], []
    for position, candidate in enumerate(smiles):
        try:
            parsed = Chem.MolFromSmiles(candidate) if candidate else None
        except Exception:
            # Unparseable entry: silently skip, same as a None result.
            parsed = None
        if parsed:
            valid_smi.append(candidate)
            valid_idx.append(position)
    return valid_smi, valid_idx
|
| 99 |
+
|
| 100 |
+
class Permeability:
    """Predicts membrane permeability for peptide SMILES strings.

    Embeds each sequence with PeptideCLM (optionally augmented with ECFP
    fingerprints and RDKit descriptors) and scores with a pretrained
    XGBoost regressor.
    """

    def __init__(self):
        self.predictor = xgb.Booster(model_file=f'{base_path}/src/permeability/best_model.json')
        self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/functions/tokenizer/new_vocab.txt',
                                              f'{base_path}/functions/tokenizer/new_splits.txt')

    def generate_embeddings(self, sequences):
        """Return a (n, dim) array of mean-pooled PeptideCLM embeddings."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_features(self, input_seqs: list, dps=False, fps=False):
        """Assemble the feature matrix [fingerprints | descriptors | embeddings].

        fps/dps default to False, so by default only embeddings are used
        (matching how get_scores calls this).
        """
        if fps:
            fingerprints = fingerprints_from_smiles(input_seqs)[0]
        else:
            # np.empty instead of torch.empty: keeps np.concatenate's inputs
            # homogeneous numpy arrays rather than relying on implicit
            # tensor->array conversion.
            fingerprints = np.empty((len(input_seqs), 0))

        if dps:
            descriptors = get_pep_dps(input_seqs)
        else:
            descriptors = np.empty((len(input_seqs), 0))

        embeddings = self.generate_embeddings(input_seqs)

        return np.concatenate([fingerprints, descriptors, embeddings], axis=1)

    def get_scores(self, input_seqs: list):
        """Predicted permeability per sequence; -10 sentinel when empty."""
        scores = -10 * np.ones(len(input_seqs))
        if len(input_seqs) == 0:
            # Avoid np.concatenate dimension errors on an empty batch.
            return scores

        features = self.get_features(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize features before handing them to XGBoost
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
|
| 158 |
+
|
| 159 |
+
def unittest():
    """Smoke test: score one peptide SMILES and print the result."""
    permeability = Permeability()
    seq = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
    scores = permeability(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
|
functions/solubility.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import xgboost as xgb
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import AutoModelForMaskedLM
|
| 7 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 8 |
+
import warnings
|
| 9 |
+
import numpy as np
|
| 10 |
+
from rdkit.Chem import Descriptors, rdMolDescriptors
|
| 11 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 12 |
+
from rdkit.Chem import AllChem
|
| 13 |
+
from typing import List
|
| 14 |
+
from transformers import AutoModelForMaskedLM
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
rdBase.DisableLog('rdApp.error')
|
| 18 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 19 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 20 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 21 |
+
|
| 22 |
+
base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
|
| 23 |
+
|
| 24 |
+
class Solubility:
    """Scores peptide SMILES strings for solubility.

    Embeds each sequence with PeptideCLM and classifies with a pretrained
    XGBoost model; returns the classifier's predicted probability.
    """

    def __init__(self):
        # Pretrained XGBoost classifier over PeptideCLM embeddings
        self.predictor = xgb.Booster(model_file=f'{base_path}/src/solubility/best_model_f1.json')
        self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/functions/tokenizer/new_vocab.txt',
                                              f'{base_path}/functions/tokenizer/new_splits.txt')

    def generate_embeddings(self, sequences):
        """Return a (n, dim) array of mean-pooled PeptideCLM embeddings."""
        pooled = []
        for seq in sequences:
            encoded = self.tokenizer(seq, return_tensors='pt')
            with torch.no_grad():
                hidden = self.emb_model(**encoded)
            # Mean pooling across sequence length
            pooled.append(hidden.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy())
        return np.array(pooled)

    def get_scores(self, input_seqs: list):
        """Score each sequence; defaults to 0.0 when no features are produced."""
        fallback = np.zeros(len(input_seqs))
        matrix = self.generate_embeddings(input_seqs)

        if len(matrix) == 0:
            return fallback

        # Sanitize embeddings before handing them to XGBoost
        matrix = np.nan_to_num(matrix, nan=0.)
        matrix = np.clip(matrix, np.finfo(np.float32).min, np.finfo(np.float32).max)

        return self.predictor.predict(xgb.DMatrix(matrix))

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
|
| 60 |
+
|
| 61 |
+
def unittest():
    """Smoke test: score one peptide SMILES and print the result."""
    solubility = Solubility()
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    scores = solubility(input_seqs=seq)
    print(scores)

if __name__ == '__main__':
    unittest()
|
functions/tokenizer/__pycache__/my_tokenizers.cpython-310.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
functions/tokenizer/my_tokenizers.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import codecs
|
| 6 |
+
import unicodedata
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
from transformers import PreTrainedTokenizer
|
| 9 |
+
from SmilesPE.tokenizer import SPE_Tokenizer
|
| 10 |
+
|
| 11 |
+
def load_vocab(vocab_file):
    """Load a newline-delimited vocabulary file into an ordered token->index map.

    Line number (0-based) becomes the token id; a duplicated token keeps its
    last occurrence's index, matching ``dict`` assignment semantics.
    """
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as handle:
        for index, line in enumerate(handle):
            vocab[line.rstrip("\n")] = index
    return vocab
|
| 20 |
+
|
| 21 |
+
class Atomwise_Tokenizer(object):
    """Run atom-level SMILES tokenization via a single regular expression."""

    def __init__(self):
        """Construct an atom-level tokenizer.

        The pattern matches, in priority order: short parenthesized groups
        (up to 4 inner chars), bracket atoms, two-letter halogens, organic
        subset atoms (upper- and lower-case/aromatic), bonds and ring-closure
        digits. Characters that match nothing are silently dropped.
        """
        # Previous, purely atom-level pattern kept for reference:
        # self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"

        self.regex = re.compile(self.regex_pattern)

    def tokenize(self, text):
        """Tokenize a SMILES string and return the list of matched tokens."""
        # findall already returns a list of the captured group for each
        # match; the original wrapped it in a redundant list comprehension.
        return self.regex.findall(text)
|
| 37 |
+
|
| 38 |
+
class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a SMILES tokenizer based on SMILES Pair Encoding
    (https://github.com/XinhaoLi74/SmilesPE).

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which
    contains most of the methods. Users should refer to the superclass for more
    information regarding methods.

    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary (one token per line).
        spe_file (:obj:`string`):
            File containing the trained SMILES Pair Encoding merge rules.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token, substituted for out-of-vocabulary tokens.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token; also the last token of a sequence built with
            special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding batched sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token; first token of a sequence built with special
            tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values in masked-language-model training.
    """

    def __init__(self, vocab_file, spe_file,
                 unk_token="[UNK]",
                 sep_token="[SEP]",
                 pad_token="[PAD]",
                 cls_token="[CLS]",
                 mask_token="[MASK]",
                 **kwargs):
        if not os.path.isfile(vocab_file):
            raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
        if not os.path.isfile(spe_file):
            raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))

        # Vocab attributes are set before super().__init__ because the
        # transformers base class may call get_vocab() during initialization.
        self.vocab = load_vocab(vocab_file)
        # BUGFIX: the SPE merges file handle used to stay open for the
        # lifetime of the tokenizer (resource leak). SPE_Tokenizer consumes
        # the file during construction, so close it right after.
        self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
        try:
            self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)
        finally:
            self.spe_vocab.close()
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs)

    @property
    def vocab_size(self):
        """Number of tokens in the base vocabulary (excludes added tokens)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the full token->id mapping, including added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Tokenize with the SPE model; it emits a space-joined token string."""
        return self.spe_tokenizer.tokenize(text).split(' ')

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        """Convert ids back to a SMILES string.

        NOTE(review): this overrides the richer base-class ``decode`` with a
        simplified signature; callers relying on extra kwargs of the
        transformers API will not reach this path unchanged.
        """
        text = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
        return self.convert_tokens_to_string(text)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Join tokens into a single string, merging WordPiece-style '##' pieces."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by adding
        special tokens, BERT-style:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`): IDs the special tokens will be added to.
            token_ids_1 (:obj:`List[int]`, `optional`): second sequence for pairs.

        Returns:
            :obj:`List[int]`: input IDs with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Produce a 0/1 mask marking special tokens; called by the tokenizer's
        ``prepare_for_model`` machinery.

        Args:
            token_ids_0 (:obj:`List[int]`): List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`): second sequence for pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list already contains special tokens.

        Returns:
            :obj:`List[int]`: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create BERT-style token-type ids: 0 for ``[CLS] A [SEP]`` and 1 for
        ``B [SEP]``. Only zeros are returned when ``token_ids_1`` is None.

        Args:
            token_ids_0 (:obj:`List[int]`): List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`): second sequence for pairs.

        Returns:
            :obj:`List[int]`: token type IDs for the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the vocabulary (one token per line, in index order) to a directory
        or to an explicit file path.

        Args:
            vocab_path (:obj:`str`): The directory (or file) to save to.

        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        index = 0
        if os.path.isdir(vocab_path):
            # BUGFIX: the original referenced an undefined VOCAB_FILES_NAMES;
            # use the conventional vocabulary file name directly.
            vocab_file = os.path.join(vocab_path, "vocab.txt")
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # BUGFIX: the original referenced an undefined `logger`.
                    logging.getLogger(__name__).warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
|
| 220 |
+
|
| 221 |
+
class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs an atom-level SMILES tokenizer. Based on SMILES Pair Encoding
    (https://github.com/XinhaoLi74/SmilesPE).

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which
    contains most of the methods. Users should refer to the superclass for more
    information regarding methods.

    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary (one token per line).
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token, substituted for out-of-vocabulary tokens.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token; also the last token of a sequence built with
            special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding batched sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token; first token of a sequence built with special
            tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values in masked-language-model training.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'.".format(vocab_file)
            )
        # BUGFIX + consistency with SMILES_SPE_Tokenizer: the vocab attributes
        # must be set BEFORE super().__init__, because the transformers base
        # class may call get_vocab() during initialization; the original set
        # them afterwards.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.tokenizer = Atomwise_Tokenizer()

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of tokens in the base vocabulary (excludes added tokens)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the full token->id mapping, including added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Tokenize at atom level with the regex-based tokenizer."""
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Join tokens into a single string, merging WordPiece-style '##' pieces."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by adding
        special tokens, BERT-style:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`): IDs the special tokens will be added to.
            token_ids_1 (:obj:`List[int]`, `optional`): second sequence for pairs.

        Returns:
            :obj:`List[int]`: input IDs with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Produce a 0/1 mask marking special tokens; called by the tokenizer's
        ``prepare_for_model`` machinery.

        Args:
            token_ids_0 (:obj:`List[int]`): List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`): second sequence for pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list already contains special tokens.

        Returns:
            :obj:`List[int]`: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create BERT-style token-type ids: 0 for ``[CLS] A [SEP]`` and 1 for
        ``B [SEP]``. Only zeros are returned when ``token_ids_1`` is None.

        Args:
            token_ids_0 (:obj:`List[int]`): List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`): second sequence for pairs.

        Returns:
            :obj:`List[int]`: token type IDs for the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the vocabulary (one token per line, in index order) to a directory
        or to an explicit file path.

        Args:
            vocab_path (:obj:`str`): The directory (or file) to save to.

        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        index = 0
        if os.path.isdir(vocab_path):
            # BUGFIX: the original referenced an undefined VOCAB_FILES_NAMES;
            # use the conventional vocabulary file name directly.
            vocab_file = os.path.join(vocab_path, "vocab.txt")
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # BUGFIX: the original referenced an undefined `logger`.
                    logging.getLogger(__name__).warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
|
functions/tokenizer/new_splits.txt
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
c 1
|
| 2 |
+
c 2
|
| 3 |
+
c 3
|
| 4 |
+
c 4
|
| 5 |
+
c 5
|
| 6 |
+
c 6
|
| 7 |
+
c 7
|
| 8 |
+
c 8
|
| 9 |
+
c 9
|
| 10 |
+
( c1
|
| 11 |
+
( c2
|
| 12 |
+
c1 )
|
| 13 |
+
c2 )
|
| 14 |
+
n 1
|
| 15 |
+
n 2
|
| 16 |
+
n 3
|
| 17 |
+
n 4
|
| 18 |
+
n 5
|
| 19 |
+
n 6
|
| 20 |
+
n 7
|
| 21 |
+
n 8
|
| 22 |
+
n 9
|
| 23 |
+
( n1
|
| 24 |
+
( n2
|
| 25 |
+
n1 )
|
| 26 |
+
n2 )
|
| 27 |
+
O 1
|
| 28 |
+
O 2
|
| 29 |
+
O 3
|
| 30 |
+
O 4
|
| 31 |
+
O 5
|
| 32 |
+
O 6
|
| 33 |
+
O 7
|
| 34 |
+
O 8
|
| 35 |
+
O 9
|
| 36 |
+
( O1
|
| 37 |
+
( O2
|
| 38 |
+
O2 )
|
| 39 |
+
O2 )
|
| 40 |
+
= O
|
| 41 |
+
= C
|
| 42 |
+
= c
|
| 43 |
+
= N
|
| 44 |
+
= n
|
| 45 |
+
=C C
|
| 46 |
+
=C N
|
| 47 |
+
=C c
|
| 48 |
+
=c c
|
| 49 |
+
=N C
|
| 50 |
+
=N c
|
| 51 |
+
=n C
|
| 52 |
+
=n c
|
| 53 |
+
# N
|
| 54 |
+
# C
|
| 55 |
+
#N C
|
| 56 |
+
#C C
|
| 57 |
+
#C N
|
| 58 |
+
#N N
|
| 59 |
+
( C
|
| 60 |
+
C )
|
| 61 |
+
( O
|
| 62 |
+
O )
|
| 63 |
+
( N
|
| 64 |
+
N )
|
| 65 |
+
Br c
|
| 66 |
+
( =O
|
| 67 |
+
(=O )
|
| 68 |
+
C (=O)
|
| 69 |
+
C =O
|
| 70 |
+
C =N
|
| 71 |
+
C #N
|
| 72 |
+
C #C
|
| 73 |
+
C C
|
| 74 |
+
CC C
|
| 75 |
+
CC N
|
| 76 |
+
CC O
|
| 77 |
+
CC S
|
| 78 |
+
CC c
|
| 79 |
+
CC n
|
| 80 |
+
C N
|
| 81 |
+
CN C
|
| 82 |
+
CN c
|
| 83 |
+
C O
|
| 84 |
+
CO C
|
| 85 |
+
CO N
|
| 86 |
+
CO c
|
| 87 |
+
C S
|
| 88 |
+
CS C
|
| 89 |
+
CS S
|
| 90 |
+
CS c
|
| 91 |
+
C c
|
| 92 |
+
Cl c
|
| 93 |
+
C n
|
| 94 |
+
F c
|
| 95 |
+
N C
|
| 96 |
+
NC C
|
| 97 |
+
NC c
|
| 98 |
+
N N
|
| 99 |
+
N O
|
| 100 |
+
N c
|
| 101 |
+
N n
|
| 102 |
+
O C
|
| 103 |
+
OC C
|
| 104 |
+
OC O
|
| 105 |
+
OC c
|
| 106 |
+
O N
|
| 107 |
+
O O
|
| 108 |
+
O c
|
| 109 |
+
S C
|
| 110 |
+
SC C
|
| 111 |
+
SC c
|
| 112 |
+
S S
|
| 113 |
+
S c
|
| 114 |
+
c c
|
| 115 |
+
cc c
|
| 116 |
+
cc n
|
| 117 |
+
cc o
|
| 118 |
+
cc s
|
| 119 |
+
cc cc
|
| 120 |
+
c n
|
| 121 |
+
cn c
|
| 122 |
+
cn n
|
| 123 |
+
c o
|
| 124 |
+
co c
|
| 125 |
+
c s
|
| 126 |
+
cs c
|
| 127 |
+
cs n
|
| 128 |
+
n c
|
| 129 |
+
nc c
|
| 130 |
+
nc n
|
| 131 |
+
nc o
|
| 132 |
+
nc s
|
| 133 |
+
n n
|
| 134 |
+
nn c
|
| 135 |
+
nn n
|
| 136 |
+
n o
|
| 137 |
+
no c
|
| 138 |
+
no n
|
| 139 |
+
n s
|
| 140 |
+
ns c
|
| 141 |
+
ns n
|
| 142 |
+
o c
|
| 143 |
+
oc c
|
| 144 |
+
o n
|
| 145 |
+
s c
|
| 146 |
+
sc c
|
| 147 |
+
sc n
|
| 148 |
+
s n
|
| 149 |
+
N P
|
| 150 |
+
P N
|
| 151 |
+
C P
|
| 152 |
+
P C
|
| 153 |
+
N S
|
| 154 |
+
S N
|
| 155 |
+
C S
|
| 156 |
+
S C
|
| 157 |
+
S P
|
| 158 |
+
P S
|
| 159 |
+
C I
|
functions/tokenizer/new_vocab.txt
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[PAD]
|
| 2 |
+
[UNK]
|
| 3 |
+
[CLS]
|
| 4 |
+
[SEP]
|
| 5 |
+
[MASK]
|
| 6 |
+
#
|
| 7 |
+
%
|
| 8 |
+
(
|
| 9 |
+
)
|
| 10 |
+
+
|
| 11 |
+
-
|
| 12 |
+
/
|
| 13 |
+
0
|
| 14 |
+
1
|
| 15 |
+
2
|
| 16 |
+
3
|
| 17 |
+
4
|
| 18 |
+
5
|
| 19 |
+
6
|
| 20 |
+
7
|
| 21 |
+
8
|
| 22 |
+
9
|
| 23 |
+
=
|
| 24 |
+
@
|
| 25 |
+
A
|
| 26 |
+
B
|
| 27 |
+
Br
|
| 28 |
+
Brc
|
| 29 |
+
C
|
| 30 |
+
CC
|
| 31 |
+
CCC
|
| 32 |
+
CCN
|
| 33 |
+
CCO
|
| 34 |
+
CCS
|
| 35 |
+
CCc
|
| 36 |
+
CCn
|
| 37 |
+
CN
|
| 38 |
+
CNC
|
| 39 |
+
CNc
|
| 40 |
+
CO
|
| 41 |
+
COC
|
| 42 |
+
CON
|
| 43 |
+
COc
|
| 44 |
+
CS
|
| 45 |
+
CSC
|
| 46 |
+
CSS
|
| 47 |
+
CSc
|
| 48 |
+
Cc
|
| 49 |
+
Cl
|
| 50 |
+
Clc
|
| 51 |
+
Cn
|
| 52 |
+
F
|
| 53 |
+
Fc
|
| 54 |
+
H
|
| 55 |
+
I
|
| 56 |
+
K
|
| 57 |
+
L
|
| 58 |
+
M
|
| 59 |
+
N
|
| 60 |
+
NC
|
| 61 |
+
NCC
|
| 62 |
+
NCc
|
| 63 |
+
NN
|
| 64 |
+
NO
|
| 65 |
+
Nc
|
| 66 |
+
Nn
|
| 67 |
+
O
|
| 68 |
+
OC
|
| 69 |
+
OCC
|
| 70 |
+
OCO
|
| 71 |
+
OCc
|
| 72 |
+
ON
|
| 73 |
+
OO
|
| 74 |
+
Oc
|
| 75 |
+
P
|
| 76 |
+
R
|
| 77 |
+
S
|
| 78 |
+
SC
|
| 79 |
+
SCC
|
| 80 |
+
SCc
|
| 81 |
+
SS
|
| 82 |
+
Sc
|
| 83 |
+
T
|
| 84 |
+
X
|
| 85 |
+
Z
|
| 86 |
+
[
|
| 87 |
+
\\
|
| 88 |
+
(/
|
| 89 |
+
]
|
| 90 |
+
a
|
| 91 |
+
b
|
| 92 |
+
c
|
| 93 |
+
cc
|
| 94 |
+
ccc
|
| 95 |
+
ccn
|
| 96 |
+
cco
|
| 97 |
+
ccs
|
| 98 |
+
cn
|
| 99 |
+
cnc
|
| 100 |
+
cnn
|
| 101 |
+
co
|
| 102 |
+
coc
|
| 103 |
+
cs
|
| 104 |
+
csc
|
| 105 |
+
csn
|
| 106 |
+
e
|
| 107 |
+
g
|
| 108 |
+
i
|
| 109 |
+
l
|
| 110 |
+
n
|
| 111 |
+
nc
|
| 112 |
+
ncc
|
| 113 |
+
ncn
|
| 114 |
+
nco
|
| 115 |
+
ncs
|
| 116 |
+
nn
|
| 117 |
+
nnc
|
| 118 |
+
nnn
|
| 119 |
+
no
|
| 120 |
+
noc
|
| 121 |
+
non
|
| 122 |
+
ns
|
| 123 |
+
nsc
|
| 124 |
+
nsn
|
| 125 |
+
o
|
| 126 |
+
oc
|
| 127 |
+
occ
|
| 128 |
+
on
|
| 129 |
+
p
|
| 130 |
+
r
|
| 131 |
+
s
|
| 132 |
+
sc
|
| 133 |
+
scc
|
| 134 |
+
scn
|
| 135 |
+
sn
|
| 136 |
+
t
|
| 137 |
+
c1
|
| 138 |
+
c2
|
| 139 |
+
c3
|
| 140 |
+
c4
|
| 141 |
+
c5
|
| 142 |
+
c6
|
| 143 |
+
c7
|
| 144 |
+
c8
|
| 145 |
+
c9
|
| 146 |
+
n1
|
| 147 |
+
n2
|
| 148 |
+
n3
|
| 149 |
+
n4
|
| 150 |
+
n5
|
| 151 |
+
n6
|
| 152 |
+
n7
|
| 153 |
+
n8
|
| 154 |
+
n9
|
| 155 |
+
O1
|
| 156 |
+
O2
|
| 157 |
+
O3
|
| 158 |
+
O4
|
| 159 |
+
O5
|
| 160 |
+
O6
|
| 161 |
+
O7
|
| 162 |
+
O8
|
| 163 |
+
O9
|
| 164 |
+
(c1
|
| 165 |
+
(c2
|
| 166 |
+
c1)
|
| 167 |
+
c2)
|
| 168 |
+
(n1
|
| 169 |
+
(n2
|
| 170 |
+
n1)
|
| 171 |
+
n2)
|
| 172 |
+
(O1
|
| 173 |
+
(O2
|
| 174 |
+
O2)
|
| 175 |
+
=O
|
| 176 |
+
=C
|
| 177 |
+
=c
|
| 178 |
+
=N
|
| 179 |
+
=n
|
| 180 |
+
=CC
|
| 181 |
+
=CN
|
| 182 |
+
=Cc
|
| 183 |
+
=cc
|
| 184 |
+
=NC
|
| 185 |
+
=Nc
|
| 186 |
+
=nC
|
| 187 |
+
=nc
|
| 188 |
+
#C
|
| 189 |
+
#CC
|
| 190 |
+
#CN
|
| 191 |
+
#N
|
| 192 |
+
#NC
|
| 193 |
+
#NN
|
| 194 |
+
(C
|
| 195 |
+
C)
|
| 196 |
+
(O
|
| 197 |
+
O)
|
| 198 |
+
(N
|
| 199 |
+
N)
|
| 200 |
+
NP
|
| 201 |
+
PN
|
| 202 |
+
CP
|
| 203 |
+
PC
|
| 204 |
+
NS
|
| 205 |
+
SN
|
| 206 |
+
SP
|
| 207 |
+
PS
|
| 208 |
+
C(=O)
|
| 209 |
+
(/Br)
|
| 210 |
+
(/C#N)
|
| 211 |
+
(/C)
|
| 212 |
+
(/C=N)
|
| 213 |
+
(/C=O)
|
| 214 |
+
(/CBr)
|
| 215 |
+
(/CC)
|
| 216 |
+
(/CCC)
|
| 217 |
+
(/CCF)
|
| 218 |
+
(/CCN)
|
| 219 |
+
(/CCO)
|
| 220 |
+
(/CCl)
|
| 221 |
+
(/CI)
|
| 222 |
+
(/CN)
|
| 223 |
+
(/CO)
|
| 224 |
+
(/CS)
|
| 225 |
+
(/Cl)
|
| 226 |
+
(/F)
|
| 227 |
+
(/I)
|
| 228 |
+
(/N)
|
| 229 |
+
(/NC)
|
| 230 |
+
(/NCC)
|
| 231 |
+
(/NO)
|
| 232 |
+
(/O)
|
| 233 |
+
(/OC)
|
| 234 |
+
(/OCC)
|
| 235 |
+
(/S)
|
| 236 |
+
(/SC)
|
| 237 |
+
(=C)
|
| 238 |
+
(=C/C)
|
| 239 |
+
(=C/F)
|
| 240 |
+
(=C/I)
|
| 241 |
+
(=C/N)
|
| 242 |
+
(=C/O)
|
| 243 |
+
(=CBr)
|
| 244 |
+
(=CC)
|
| 245 |
+
(=CCF)
|
| 246 |
+
(=CCN)
|
| 247 |
+
(=CCO)
|
| 248 |
+
(=CCl)
|
| 249 |
+
(=CF)
|
| 250 |
+
(=CI)
|
| 251 |
+
(=CN)
|
| 252 |
+
(=CO)
|
| 253 |
+
(=C\\C)
|
| 254 |
+
(=C\\F)
|
| 255 |
+
(=C\\I)
|
| 256 |
+
(=C\\N)
|
| 257 |
+
(=C\\O)
|
| 258 |
+
(=N)
|
| 259 |
+
(=N/C)
|
| 260 |
+
(=N/N)
|
| 261 |
+
(=N/O)
|
| 262 |
+
(=NBr)
|
| 263 |
+
(=NC)
|
| 264 |
+
(=NCC)
|
| 265 |
+
(=NCl)
|
| 266 |
+
(=NN)
|
| 267 |
+
(=NO)
|
| 268 |
+
(=NOC)
|
| 269 |
+
(=N\\C)
|
| 270 |
+
(=N\\N)
|
| 271 |
+
(=N\\O)
|
| 272 |
+
(=O)
|
| 273 |
+
(=S)
|
| 274 |
+
(B)
|
| 275 |
+
(Br)
|
| 276 |
+
(C#C)
|
| 277 |
+
(C#CC)
|
| 278 |
+
(C#CI)
|
| 279 |
+
(C#CO)
|
| 280 |
+
(C#N)
|
| 281 |
+
(C#SN)
|
| 282 |
+
(C)
|
| 283 |
+
(C=C)
|
| 284 |
+
(C=CF)
|
| 285 |
+
(C=CI)
|
| 286 |
+
(C=N)
|
| 287 |
+
(C=NN)
|
| 288 |
+
(C=NO)
|
| 289 |
+
(C=O)
|
| 290 |
+
(C=S)
|
| 291 |
+
(CBr)
|
| 292 |
+
(CC#C)
|
| 293 |
+
(CC#N)
|
| 294 |
+
(CC)
|
| 295 |
+
(CC=C)
|
| 296 |
+
(CC=O)
|
| 297 |
+
(CCBr)
|
| 298 |
+
(CCC)
|
| 299 |
+
(CCCC)
|
| 300 |
+
(CCCF)
|
| 301 |
+
(CCCI)
|
| 302 |
+
(CCCN)
|
| 303 |
+
(CCCO)
|
| 304 |
+
(CCCS)
|
| 305 |
+
(CCCl)
|
| 306 |
+
(CCF)
|
| 307 |
+
(CCI)
|
| 308 |
+
(CCN)
|
| 309 |
+
(CCNC)
|
| 310 |
+
(CCNN)
|
| 311 |
+
(CCNO)
|
| 312 |
+
(CCO)
|
| 313 |
+
(CCOC)
|
| 314 |
+
(CCON)
|
| 315 |
+
(CCS)
|
| 316 |
+
(CCSC)
|
| 317 |
+
(CCl)
|
| 318 |
+
(CF)
|
| 319 |
+
(CI)
|
| 320 |
+
(CN)
|
| 321 |
+
(CN=O)
|
| 322 |
+
(CNC)
|
| 323 |
+
(CNCC)
|
| 324 |
+
(CNCO)
|
| 325 |
+
(CNN)
|
| 326 |
+
(CNNC)
|
| 327 |
+
(CNO)
|
| 328 |
+
(CNOC)
|
| 329 |
+
(CO)
|
| 330 |
+
(COC)
|
| 331 |
+
(COCC)
|
| 332 |
+
(COCI)
|
| 333 |
+
(COCN)
|
| 334 |
+
(COCO)
|
| 335 |
+
(COF)
|
| 336 |
+
(CON)
|
| 337 |
+
(COO)
|
| 338 |
+
(CS)
|
| 339 |
+
(CSC)
|
| 340 |
+
(CSCC)
|
| 341 |
+
(CSCF)
|
| 342 |
+
(CSO)
|
| 343 |
+
(Cl)
|
| 344 |
+
(F)
|
| 345 |
+
(I)
|
| 346 |
+
(N)
|
| 347 |
+
(N=N)
|
| 348 |
+
(N=NO)
|
| 349 |
+
(N=O)
|
| 350 |
+
(N=S)
|
| 351 |
+
(NBr)
|
| 352 |
+
(NC#N)
|
| 353 |
+
(NC)
|
| 354 |
+
(NC=N)
|
| 355 |
+
(NC=O)
|
| 356 |
+
(NC=S)
|
| 357 |
+
(NCBr)
|
| 358 |
+
(NCC)
|
| 359 |
+
(NCCC)
|
| 360 |
+
(NCCF)
|
| 361 |
+
(NCCN)
|
| 362 |
+
(NCCO)
|
| 363 |
+
(NCCS)
|
| 364 |
+
(NCCl)
|
| 365 |
+
(NCNC)
|
| 366 |
+
(NCO)
|
| 367 |
+
(NCS)
|
| 368 |
+
(NCl)
|
| 369 |
+
(NN)
|
| 370 |
+
(NN=O)
|
| 371 |
+
(NNC)
|
| 372 |
+
(NO)
|
| 373 |
+
(NOC)
|
| 374 |
+
(O)
|
| 375 |
+
(OC#N)
|
| 376 |
+
(OC)
|
| 377 |
+
(OC=C)
|
| 378 |
+
(OC=O)
|
| 379 |
+
(OC=S)
|
| 380 |
+
(OCBr)
|
| 381 |
+
(OCC)
|
| 382 |
+
(OCCC)
|
| 383 |
+
(OCCF)
|
| 384 |
+
(OCCI)
|
| 385 |
+
(OCCN)
|
| 386 |
+
(OCCO)
|
| 387 |
+
(OCCS)
|
| 388 |
+
(OCCl)
|
| 389 |
+
(OCF)
|
| 390 |
+
(OCI)
|
| 391 |
+
(OCO)
|
| 392 |
+
(OCOC)
|
| 393 |
+
(OCON)
|
| 394 |
+
(OCSC)
|
| 395 |
+
(OCl)
|
| 396 |
+
(OI)
|
| 397 |
+
(ON)
|
| 398 |
+
(OO)
|
| 399 |
+
(OOC)
|
| 400 |
+
(OOCC)
|
| 401 |
+
(OOSN)
|
| 402 |
+
(OSC)
|
| 403 |
+
(P)
|
| 404 |
+
(S)
|
| 405 |
+
(SC#N)
|
| 406 |
+
(SC)
|
| 407 |
+
(SCC)
|
| 408 |
+
(SCCC)
|
| 409 |
+
(SCCF)
|
| 410 |
+
(SCCN)
|
| 411 |
+
(SCCO)
|
| 412 |
+
(SCCS)
|
| 413 |
+
(SCCl)
|
| 414 |
+
(SCF)
|
| 415 |
+
(SCN)
|
| 416 |
+
(SCOC)
|
| 417 |
+
(SCSC)
|
| 418 |
+
(SCl)
|
| 419 |
+
(SI)
|
| 420 |
+
(SN)
|
| 421 |
+
(SN=O)
|
| 422 |
+
(SO)
|
| 423 |
+
(SOC)
|
| 424 |
+
(SOOO)
|
| 425 |
+
(SS)
|
| 426 |
+
(SSC)
|
| 427 |
+
(SSCC)
|
| 428 |
+
([At])
|
| 429 |
+
([O-])
|
| 430 |
+
([O])
|
| 431 |
+
([S-])
|
| 432 |
+
(\\Br)
|
| 433 |
+
(\\C#N)
|
| 434 |
+
(\\C)
|
| 435 |
+
(\\C=N)
|
| 436 |
+
(\\C=O)
|
| 437 |
+
(\\CBr)
|
| 438 |
+
(\\CC)
|
| 439 |
+
(\\CCC)
|
| 440 |
+
(\\CCO)
|
| 441 |
+
(\\CCl)
|
| 442 |
+
(\\CF)
|
| 443 |
+
(\\CN)
|
| 444 |
+
(\\CNC)
|
| 445 |
+
(\\CO)
|
| 446 |
+
(\\COC)
|
| 447 |
+
(\\Cl)
|
| 448 |
+
(\\F)
|
| 449 |
+
(\\I)
|
| 450 |
+
(\\N)
|
| 451 |
+
(\\NC)
|
| 452 |
+
(\\NCC)
|
| 453 |
+
(\\NN)
|
| 454 |
+
(\\NO)
|
| 455 |
+
(\\NOC)
|
| 456 |
+
(\\O)
|
| 457 |
+
(\\OC)
|
| 458 |
+
(\\OCC)
|
| 459 |
+
(\\ON)
|
| 460 |
+
(\\S)
|
| 461 |
+
(\\SC)
|
| 462 |
+
(\\SCC)
|
| 463 |
+
[Ag+]
|
| 464 |
+
[Ag-4]
|
| 465 |
+
[Ag]
|
| 466 |
+
[Al-3]
|
| 467 |
+
[Al]
|
| 468 |
+
[As+]
|
| 469 |
+
[AsH3]
|
| 470 |
+
[AsH]
|
| 471 |
+
[As]
|
| 472 |
+
[At]
|
| 473 |
+
[B-]
|
| 474 |
+
[B@-]
|
| 475 |
+
[B@@-]
|
| 476 |
+
[BH-]
|
| 477 |
+
[BH2-]
|
| 478 |
+
[BH3-]
|
| 479 |
+
[B]
|
| 480 |
+
[Ba]
|
| 481 |
+
[Br+2]
|
| 482 |
+
[BrH]
|
| 483 |
+
[Br]
|
| 484 |
+
[C+]
|
| 485 |
+
[C-]
|
| 486 |
+
[C@@H]
|
| 487 |
+
[C@@]
|
| 488 |
+
[C@H]
|
| 489 |
+
[C@]
|
| 490 |
+
[CH-]
|
| 491 |
+
[CH2]
|
| 492 |
+
[CH3]
|
| 493 |
+
[CH]
|
| 494 |
+
[C]
|
| 495 |
+
[CaH2]
|
| 496 |
+
[Ca]
|
| 497 |
+
[Cl+2]
|
| 498 |
+
[Cl+3]
|
| 499 |
+
[Cl+]
|
| 500 |
+
[Cs]
|
| 501 |
+
[FH]
|
| 502 |
+
[F]
|
| 503 |
+
[H]
|
| 504 |
+
[He]
|
| 505 |
+
[I+2]
|
| 506 |
+
[I+3]
|
| 507 |
+
[I+]
|
| 508 |
+
[IH]
|
| 509 |
+
[I]
|
| 510 |
+
[K]
|
| 511 |
+
[Kr]
|
| 512 |
+
[Li+]
|
| 513 |
+
[LiH]
|
| 514 |
+
[MgH2]
|
| 515 |
+
[Mg]
|
| 516 |
+
[N+]
|
| 517 |
+
[N-]
|
| 518 |
+
[N@+]
|
| 519 |
+
[N@@+]
|
| 520 |
+
[N@@]
|
| 521 |
+
[N@]
|
| 522 |
+
[NH+]
|
| 523 |
+
[NH-]
|
| 524 |
+
[NH2+]
|
| 525 |
+
[NH3]
|
| 526 |
+
[NH]
|
| 527 |
+
[N]
|
| 528 |
+
[Na]
|
| 529 |
+
[O+]
|
| 530 |
+
[O-]
|
| 531 |
+
[OH+]
|
| 532 |
+
[OH2]
|
| 533 |
+
[OH]
|
| 534 |
+
[O]
|
| 535 |
+
[P+]
|
| 536 |
+
[P@+]
|
| 537 |
+
[P@@+]
|
| 538 |
+
[P@@]
|
| 539 |
+
[P@]
|
| 540 |
+
[PH2]
|
| 541 |
+
[PH]
|
| 542 |
+
[P]
|
| 543 |
+
[Ra]
|
| 544 |
+
[Rb]
|
| 545 |
+
[S+]
|
| 546 |
+
[S-]
|
| 547 |
+
[S@+]
|
| 548 |
+
[S@@+]
|
| 549 |
+
[S@@]
|
| 550 |
+
[S@]
|
| 551 |
+
[SH+]
|
| 552 |
+
[SH2]
|
| 553 |
+
[SH]
|
| 554 |
+
[S]
|
| 555 |
+
[Se+]
|
| 556 |
+
[Se-2]
|
| 557 |
+
[SeH2]
|
| 558 |
+
[SeH]
|
| 559 |
+
[Se]
|
| 560 |
+
[Si@]
|
| 561 |
+
[SiH2]
|
| 562 |
+
[SiH]
|
| 563 |
+
[Si]
|
| 564 |
+
[SrH2]
|
| 565 |
+
[TeH]
|
| 566 |
+
[Te]
|
| 567 |
+
[Xe]
|
| 568 |
+
[Zn+2]
|
| 569 |
+
[Zn-2]
|
| 570 |
+
[Zn]
|
| 571 |
+
[b-]
|
| 572 |
+
[c+]
|
| 573 |
+
[c-]
|
| 574 |
+
[cH-]
|
| 575 |
+
[cH]
|
| 576 |
+
[c]
|
| 577 |
+
[n+]
|
| 578 |
+
[n-]
|
| 579 |
+
[nH]
|
| 580 |
+
[n]
|
| 581 |
+
[o+]
|
| 582 |
+
[s+]
|
| 583 |
+
[se+]
|
| 584 |
+
[se]
|
| 585 |
+
[te+]
|
| 586 |
+
[te]
|
scoring_functions.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import io
|
| 3 |
+
import subprocess
|
| 4 |
+
import warnings
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from typing import List
|
| 8 |
+
from loguru import logger
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 11 |
+
from rdkit.Chem import AllChem
|
| 12 |
+
import torch
|
| 13 |
+
from functions.binding.binding import BindingAffinity
|
| 14 |
+
from functions.permeability.permeability import Permeability
|
| 15 |
+
from functions.solubility.solubility import Solubility
|
| 16 |
+
from functions.hemolysis.hemolysis import Hemolysis
|
| 17 |
+
from functions.nonfouling.nonfouling import Nonfouling
|
| 18 |
+
|
| 19 |
+
class ScoringFunctions:
    """Evaluate a configurable set of scoring functions on generated sequences.

    Each named scoring function maps a list of SMILES strings to one score per
    sequence; ``forward``/``__call__`` stack the results into a float32 array
    of shape (num_sequences, num_functions), with columns ordered like
    ``score_func_names``.
    """

    def __init__(self, score_func_names=None, prot_seqs=None):
        """
        Args:
            score_func_names: names of the scoring functions to evaluate
                (subset of 'binding_affinity1', 'binding_affinity2',
                'permeability', 'nonfouling', 'solubility', 'hemolysis').
                ``None`` means no scoring functions are evaluated.
            prot_seqs: target-protein sequences; one per requested
                binding-affinity function (at most two).
        """
        # Normalise the defaults up front: ``None`` -> empty list.  All
        # membership tests below use the normalised attribute — the original
        # code tested the raw ``score_func_names`` argument, which raised
        # ``TypeError: argument of type 'NoneType' is not iterable`` whenever
        # the caller relied on the default.  ``prot_seqs`` also no longer uses
        # a mutable default argument.
        self.score_func_names = [] if score_func_names is None else score_func_names
        self.target_protein = [] if prot_seqs is None else prot_seqs

        names = self.score_func_names
        prot_seqs = self.target_protein

        # Binding-affinity predictors are only built when requested, since
        # each one is conditioned on a specific target-protein sequence.
        if ('binding_affinity1' in names) and (len(prot_seqs) == 1):
            binding_affinity1 = BindingAffinity(prot_seqs[0])
            binding_affinity2 = None
        elif ('binding_affinity1' in names) and ('binding_affinity2' in names) and (len(prot_seqs) == 2):
            binding_affinity1 = BindingAffinity(prot_seqs[0])
            binding_affinity2 = BindingAffinity(prot_seqs[1])
        else:
            binding_affinity1 = None
            binding_affinity2 = None

        # The remaining predictors are target-independent.
        self.all_funcs = {
            'binding_affinity1': binding_affinity1,
            'binding_affinity2': binding_affinity2,
            'permeability': Permeability(),
            'nonfouling': Nonfouling(),
            'solubility': Solubility(),
            'hemolysis': Hemolysis(),
        }

    def forward(self, input_seqs):
        """Score ``input_seqs`` with every configured scoring function.

        Args:
            input_seqs: list of SMILES strings.

        Returns:
            np.ndarray of shape (num_sequences, num_functions), float32.
        """
        scores = []
        for score_func in self.score_func_names:
            scores.append(self.all_funcs[score_func](input_seqs=input_seqs))

        # Transpose so rows correspond to sequences, columns to functions.
        return np.float32(scores).T

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def unittest():
    """Smoke test: score one cyclic-peptide SMILES against the TfR target.

    The long literals below are full one-letter amino-acid sequences for
    candidate target proteins; only ``tfr`` (transferrin receptor) is used
    by this test — the others are kept for convenience when switching
    targets.
    """
    # Candidate target-protein sequences.
    amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
    tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
    gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
    glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
    glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
    ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
    cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'

    # NOTE(review): both of these appear unused in this test — confirm before
    # removing.
    num_iter = 0
    score_func_times = [0, 1, 2, 3, 4, 5]

    # One binding target (tfr) plus the four target-independent predictors.
    scoring = ScoringFunctions(score_func_names=['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability'], prot_seqs=[tfr])

    # A single cyclic-peptide SMILES string to score.
    smiles = ['N2[C@H](CC(C)C)C(=O)N1[C@@H](CCC1)C(=O)N1[C@@H](CCC1)C(=O)N1[C@@H](CCC1)C(=O)N[C@@H](Cc1ccccc1C(F)(F)F)C(=O)N1[C@@H](CCC1)C(=O)N[C@@H](CCSC)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](CC(=O)N)C2(=O)']

    # Expect a (num_sequences, num_functions) array of scores.
    scores = scoring(input_seqs=smiles)
    print(scores)
    print(len(scores))
|
| 101 |
+
|
| 102 |
+
# Run the smoke test when this module is executed as a script.
if __name__ == '__main__':
    unittest()
|
train/binary_xg.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from sklearn.model_selection import train_test_split
|
| 5 |
+
from sklearn.metrics import precision_recall_curve, f1_score
|
| 6 |
+
import optuna
|
| 7 |
+
from optuna.trial import TrialState
|
| 8 |
+
import xgboost as xgb
|
| 9 |
+
import os
|
| 10 |
+
from datasets import load_from_disk
|
| 11 |
+
from lightning.pytorch import seed_everything
|
| 12 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 13 |
+
from typing import List
|
| 14 |
+
from rdkit.Chem import AllChem
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
from sklearn.metrics import accuracy_score, roc_auc_score
|
| 17 |
+
|
| 18 |
+
# Project root on the training cluster; data and model artefacts are resolved
# relative to this path.
base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
|
| 19 |
+
|
| 20 |
+
def save_and_plot_binary_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, threshold, output_path):
    """Write train/validation prediction tables and classification plots.

    For each split, a CSV with the true label, predicted probability and
    thresholded hard label is written to *output_path*; a scatter plot of
    probability vs. true label is then rendered for each split.

    Args:
        y_true_train: true labels for the training set.
        y_pred_train: predicted probabilities for the training set.
        y_true_val: true labels for the validation set.
        y_pred_val: predicted probabilities for the validation set.
        threshold: probability cut-off used to derive hard labels.
        output_path: directory receiving the CSVs and plots (created if missing).
    """
    os.makedirs(output_path, exist_ok=True)

    splits = {
        'train': (y_true_train, y_pred_train),
        'val': (y_true_val, y_pred_val),
    }

    # CSV tables first (same output order as before), then the plots.
    for prefix, (truths, probs) in splits.items():
        hard_labels = (probs >= threshold).astype(int)
        table = pd.DataFrame({
            'True Label': truths,
            'Predicted Probability': probs,
            'Predicted Label': hard_labels,
        })
        table.to_csv(os.path.join(output_path, f'{prefix}_predictions_binary.csv'), index=False)

    titles = {
        'train': "Training Set Binary Classification Plot",
        'val': "Validation Set Binary Classification Plot",
    }
    for prefix, (truths, probs) in splits.items():
        plot_binary_correlation(
            truths,
            probs,
            threshold,
            title=titles[prefix],
            output_file=os.path.join(output_path, f'{prefix}_classification_plot.png'),
        )
|
| 71 |
+
|
| 72 |
+
def plot_binary_correlation(y_true, y_pred, threshold, title, output_file):
    """
    Generates a scatter plot for binary classification and saves it to a file.

    Predicted probabilities are plotted against true labels, with a dashed
    horizontal line marking the classification threshold.  The figure is both
    written to *output_file* and displayed via ``plt.show()``.

    Parameters:
        y_true (array): True labels.
        y_pred (array): Predicted probabilities.
        threshold (float): Classification threshold for predictions.
        title (str): Title of the plot.
        output_file (str): Path to save the plot.
    """
    # Scatter plot
    plt.figure(figsize=(10, 8))
    plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')

    # Add threshold line (decision boundary reference)
    plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')

    # Add annotations
    plt.title(title)
    plt.xlabel("True Labels")
    plt.ylabel("Predicted Probability")
    plt.legend()

    # Save and show the plot (savefig must precede show, which clears state)
    plt.tight_layout()
    plt.savefig(output_file)
    plt.show()
|
| 100 |
+
|
| 101 |
+
# Fix all RNG seeds (Python, NumPy, torch) for reproducible runs.
seed_everything(42)

# HuggingFace dataset with per-sequence labels and precomputed embeddings.
# NOTE(review): this script trains under src/solubility despite being named
# binary_xg.py — confirm the dataset path is the intended one.
dataset = load_from_disk(f'{base_path}/data/solubility')

sequences = np.stack(dataset['sequence'])  # Ensure sequences are SMILES strings
labels = np.stack(dataset['labels'])
embeddings = np.stack(dataset['embedding'])

# Initialize best F1 score and model path
best_f1 = -np.inf
best_model_path = f"{base_path}/src/solubility"
|
| 112 |
+
|
| 113 |
+
# Trial callback
|
| 114 |
+
def trial_info_callback(study, trial):
    """Print a short report whenever *trial* is the study's best so far."""
    if study.best_trial != trial:
        return
    print(f"Trial {trial.number}:")
    print(f" Weighted F1 Score: {trial.value}")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def objective(trial):
    """Optuna objective: train one XGBoost binary classifier, return weighted F1.

    Reads the module-level ``dataset``/``sequences``/``labels`` globals.  The
    best model across trials (by validation F1) is tracked via the study's
    user attributes and persisted under ``best_model_path`` together with
    prediction CSVs and plots.
    """
    # Define hyperparameters (search space for Optuna)
    params = {
        'objective': 'binary:logistic',
        'lambda': trial.suggest_float('lambda', 1e-8, 50.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-8, 50.0, log=True),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.3, 1.0),
        'subsample': trial.suggest_float('subsample', 0.5, 1.0),
        'learning_rate': trial.suggest_float('learning_rate', 0.001, 0.3),
        'max_depth': trial.suggest_int('max_depth', 2, 15),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 500),
        'gamma': trial.suggest_float('gamma', 0, 10.0),
        'tree_method': 'hist',
        'device': 'cuda:6',  # NOTE(review): hard-coded GPU id — confirm it exists on the host
    }

    # Suggest number of boosting rounds
    num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)
    threshold = 0.5  # Initial classification threshold

    # Split the data.  The split is identical across trials (fixed
    # random_state), which keeps trial F1 scores comparable; stratification
    # preserves the label balance in both splits.
    train_idx, val_idx = train_test_split(
        np.arange(len(sequences)), test_size=0.2, stratify=labels, random_state=42
    )
    train_subset = dataset.select(train_idx).with_format("torch")
    val_subset = dataset.select(val_idx).with_format("torch")

    # Extract embeddings and labels for train/validation
    train_embeddings = np.array(train_subset['embedding'])
    valid_embeddings = np.array(val_subset['embedding'])
    train_labels = np.array(train_subset['labels'])
    valid_labels = np.array(val_subset['labels'])

    # Prepare training and validation sets
    dtrain = xgb.DMatrix(train_embeddings, label=train_labels)
    dvalid = xgb.DMatrix(valid_embeddings, label=valid_labels)

    # Train the model (early stopping monitored on the validation split)
    model = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dvalid, "validation")],
        early_stopping_rounds=50,
        verbose_eval=False,
    )

    # Predict probabilities
    preds_train = model.predict(dtrain)
    preds_val = model.predict(dvalid)

    # Calculate metrics
    f1_val = f1_score(valid_labels, (preds_val >= threshold).astype(int), average="weighted")
    auc_val = roc_auc_score(valid_labels, preds_val)
    print(f"Trial {trial.number}: AUC: {auc_val:.3f}, F1 Score: {f1_val:.3f}")

    # Save the model if it has the best F1 score.  The running best is
    # tracked in the study's user attrs so the artefact survives across
    # trials (Optuna's own best_trial only tracks the objective value).
    current_best = trial.study.user_attrs.get("best_f1", -np.inf)
    if f1_val > current_best:
        trial.study.set_user_attr("best_f1", f1_val)
        trial.study.set_user_attr("best_auc", auc_val)
        trial.study.set_user_attr("best_trial", trial.number)
        os.makedirs(best_model_path, exist_ok=True)

        # Save the model
        model.save_model(os.path.join(best_model_path, "best_model_f1.json"))
        print(f"✓ NEW BEST! Trial {trial.number}: F1={f1_val:.4f}, AUC={auc_val:.4f} - Model saved!")

        # Save and plot binary predictions
        save_and_plot_binary_predictions(
            train_labels, preds_train, valid_labels, preds_val, threshold, best_model_path
        )

    return f1_val
|
| 195 |
+
|
| 196 |
+
if __name__ == "__main__":
    # Run the hyperparameter search; objective() records the best trial's
    # metrics in the study's user attrs as a side effect.
    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=200)

    # The user attrs are absent if no trial ever improved on -inf (e.g. all
    # trials failed).  Format defensively: the original code applied ``:.4f``
    # to the ``None`` fallback, which raises TypeError.
    best_f1_val = study.user_attrs.get('best_f1')
    best_auc_val = study.user_attrs.get('best_auc')

    # Build a human-readable summary of the optimisation run.
    summary = []
    summary.append("\n" + "="*60)
    summary.append("OPTIMIZATION COMPLETE")
    summary.append("="*60)
    summary.append(f"Number of finished trials: {len(study.trials)}")
    summary.append(f"\nBest Trial: #{study.user_attrs.get('best_trial', 'N/A')}")
    summary.append(f"Best F1 Score: {best_f1_val:.4f}" if best_f1_val is not None else "Best F1 Score: N/A")
    summary.append(f"Best AUC Score: {best_auc_val:.4f}" if best_auc_val is not None else "Best AUC Score: N/A")
    summary.append(f"Optuna Best Trial Value: {study.best_trial.value:.4f}")
    summary.append(f"\nBest hyperparameters:")
    for key, value in study.best_trial.params.items():
        summary.append(f" {key}: {value}")
    summary.append("="*60)

    # Print to console
    for line in summary:
        print(line)

    # Save the same summary next to the best model artefact.
    metrics_file = os.path.join(best_model_path, "optimization_metrics.txt")
    with open(metrics_file, 'w') as f:
        f.write('\n'.join(summary))
    print(f"\n✓ Metrics saved to: {metrics_file}")
|
train/permeability_xg.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import optuna
|
| 4 |
+
from optuna.trial import TrialState
|
| 5 |
+
from rdkit import Chem
|
| 6 |
+
from rdkit.Chem import AllChem
|
| 7 |
+
from sklearn.metrics import mean_squared_error
|
| 8 |
+
from sklearn.model_selection import train_test_split
|
| 9 |
+
import xgboost as xgb
|
| 10 |
+
import os
|
| 11 |
+
from datasets import load_from_disk
|
| 12 |
+
from scipy.stats import spearmanr
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
|
| 15 |
+
# Project root on the training cluster; data and model artefacts are resolved
# relative to this path.
base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
|
| 16 |
+
|
| 17 |
+
def save_and_plot_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, output_path):
    """Persist regression predictions and correlation plots for both splits.

    Each split gets a CSV of true vs. predicted permeability plus a scatter
    plot written to *output_path* (created if missing).

    Args:
        y_true_train: true permeability values, training split.
        y_pred_train: predicted permeability values, training split.
        y_true_val: true permeability values, validation split.
        y_pred_val: predicted permeability values, validation split.
        output_path: output directory for the CSVs and PNGs.
    """
    os.makedirs(output_path, exist_ok=True)

    splits = {
        'train': (y_true_train, y_pred_train, "Training Set Correlation Plot"),
        'val': (y_true_val, y_pred_val, "Validation Set Correlation Plot"),
    }

    # CSV tables first (same output order as before), then the plots.
    for prefix, (truths, preds, _) in splits.items():
        table = pd.DataFrame({'True Permeability': truths, 'Predicted Permeability': preds})
        table.to_csv(os.path.join(output_path, f'{prefix}_predictions.csv'), index=False)

    for prefix, (truths, preds, plot_title) in splits.items():
        plot_correlation(
            truths,
            preds,
            title=plot_title,
            output_file=os.path.join(output_path, f'{prefix}_correlation.png'),
        )
|
| 43 |
+
|
| 44 |
+
def plot_correlation(y_true, y_pred, title, output_file):
    """Scatter true vs. predicted permeability with a Spearman annotation.

    Draws the data points against the identity ("ideal fit") line, titles the
    plot with the Spearman rank correlation, and both saves and shows it.

    Parameters:
        y_true (array): True permeability values (logP).
        y_pred (array): Predicted permeability values (logP).
        title (str): Title of the plot.
        output_file (str): Path to save the plot.
    """
    spearman_corr, _ = spearmanr(y_true, y_pred)

    # Scatter plot
    plt.figure(figsize=(10, 8))
    plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')
    plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], color='teal', linestyle='--', label='Ideal fit')

    # Add annotations
    plt.title(f"{title}\nSpearman Correlation: {spearman_corr:.3f}")
    plt.xlabel("True Permeability (logP)")
    # Fixed: the y-axis previously read "Predicted Affinity (logP)", a
    # copy-paste from the binding-affinity plotting code; this plot shows
    # predicted permeability.
    plt.ylabel("Predicted Permeability (logP)")
    plt.legend()

    # Save and show the plot
    plt.tight_layout()
    plt.savefig(output_file)
    plt.show()
|
| 62 |
+
|
| 63 |
+
# Load dataset (HuggingFace dataset with precomputed per-sequence embeddings)
dataset = load_from_disk(f'{base_path}/data/permeability')

# Extract sequences, labels, and embeddings
sequences = np.stack(dataset['sequence'])
labels = np.stack(dataset['labels'])  # Regression labels (permeability; presumably logP — TODO confirm units)
embeddings = np.stack(dataset['embedding'])  # Pre-trained embeddings
|
| 70 |
+
|
| 71 |
+
# Function to compute Morgan fingerprints
|
| 72 |
+
def compute_morgan_fingerprints(smiles_list, radius=2, n_bits=2048):
    """Encode SMILES strings as Morgan (circular) fingerprint bit vectors.

    Unparseable SMILES contribute an all-zero vector and are reported on
    stdout, so the output row count always matches the input length.

    Args:
        smiles_list: iterable of SMILES strings.
        radius: Morgan fingerprint radius.
        n_bits: fingerprint length in bits.

    Returns:
        np.ndarray of shape (len(smiles_list), n_bits).
    """
    encoded = []
    for entry in smiles_list:
        parsed = Chem.MolFromSmiles(entry)
        if parsed is None:
            # If the SMILES string is invalid, use a zero vector
            encoded.append(np.zeros(n_bits))
            print(f"Invalid SMILES: {entry}")
            continue
        bit_vec = AllChem.GetMorganFingerprintAsBitVect(parsed, radius, nBits=n_bits)
        encoded.append(np.array(bit_vec))
    return np.array(encoded)
|
| 84 |
+
|
| 85 |
+
# Compute Morgan fingerprints for the sequences
# (disabled: the model currently trains on embeddings alone; re-enable both
# commented lines below to append Morgan fingerprints to each feature vector)
#morgan_fingerprints = compute_morgan_fingerprints(sequences)

# Concatenate embeddings with Morgan fingerprints
#input_features = np.concatenate([embeddings, morgan_fingerprints], axis=1)
input_features = embeddings

# Initialize global variables (output directory for the best model artefacts)
best_model_path = f"{base_path}/src/permeability"
os.makedirs(best_model_path, exist_ok=True)
|
| 95 |
+
|
| 96 |
+
def trial_info_callback(study, trial):
    """Optuna callback: print a trial's MSE when it becomes the study's best."""
    if study.best_trial != trial:
        return
    print(f"Trial {trial.number}:")
    print(f" MSE: {trial.value}")
|
| 100 |
+
|
| 101 |
+
def objective(trial):
    """Optuna objective: train an XGBoost regressor and return validation MSE.

    Samples hyperparameters, trains with early stopping on a fixed 80/20
    split of the module-level `input_features`/`labels`, and persists the
    model plus prediction plots whenever a new best MSE is reached
    (tracked via study user attributes).

    Args:
        trial: the Optuna trial supplying hyperparameter suggestions.

    Returns:
        Validation-set mean squared error (the quantity being minimized).
    """
    # Hyperparameter search space.
    # FIX: `scale_pos_weight` was removed — it only affects binary
    # classification objectives and is ignored by 'reg:squarederror',
    # so tuning it wasted a search dimension.
    params = {
        'objective': 'reg:squarederror',
        'lambda': trial.suggest_float('lambda', 0.1, 10.0, log=True),
        'alpha': trial.suggest_float('alpha', 0.1, 10.0, log=True),
        'gamma': trial.suggest_float('gamma', 0, 5),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
        'subsample': trial.suggest_float('subsample', 0.6, 0.9),
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 0.1),
        'max_depth': trial.suggest_int('max_depth', 2, 30),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
        'tree_method': 'hist',
        'device': 'cuda:6',  # NOTE(review): hard-coded GPU index — adjust per machine
    }
    num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)

    # Fixed-seed split so every trial is evaluated on the same validation set.
    X_train, X_val, y_train, y_val = train_test_split(input_features, labels, test_size=0.2, random_state=42)

    # Convert data to DMatrix
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dvalid = xgb.DMatrix(X_val, label=y_val)

    # Train with early stopping monitored on the validation set.
    model = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dvalid, "validation")],
        early_stopping_rounds=50,
        verbose_eval=False,
    )

    # Predict and evaluate
    preds_train = model.predict(dtrain)
    preds_val = model.predict(dvalid)

    mse = mean_squared_error(y_val, preds_val)

    # Spearman rank correlation on both splits (reported, not optimized).
    spearman_train, _ = spearmanr(y_train, preds_train)
    spearman_val, _ = spearmanr(y_val, preds_val)
    print(f"Train Spearman: {spearman_train:.4f}, Val Spearman: {spearman_val:.4f}")

    # Persist model and plots whenever this trial beats the best MSE so far.
    if trial.study.user_attrs.get("best_mse", np.inf) > mse:
        trial.study.set_user_attr("best_mse", mse)
        trial.study.set_user_attr("best_spearman_train", spearman_train)
        trial.study.set_user_attr("best_spearman_val", spearman_val)
        trial.study.set_user_attr("best_trial", trial.number)
        model.save_model(os.path.join(best_model_path, "best_model.json"))
        save_and_plot_predictions(y_train, preds_train, y_val, preds_val, best_model_path)
        print(f"✓ NEW BEST! Trial {trial.number}: MSE={mse:.4f}, Train Spearman={spearman_train:.4f}, Val Spearman={spearman_val:.4f}")

    return mse
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
    # Run the hyperparameter search; the median pruner stops unpromising trials.
    study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=200, callbacks=[trial_info_callback])

    # BUG FIX: the Spearman user attributes are only set by successful trials.
    # Formatting a missing value (None) with ':.4f' raises TypeError, so guard it.
    best_sp_train = study.user_attrs.get('best_spearman_train')
    best_sp_val = study.user_attrs.get('best_spearman_val')

    # Prepare summary text
    summary = []
    summary.append("\n" + "="*60)
    summary.append("OPTIMIZATION COMPLETE")
    summary.append("="*60)
    summary.append(f"Number of finished trials: {len(study.trials)}")
    summary.append(f"\nBest Trial: #{study.user_attrs.get('best_trial', 'N/A')}")
    summary.append(f"Best MSE: {study.best_trial.value:.4f}")
    if best_sp_train is not None:
        summary.append(f"Best Training Spearman Correlation: {best_sp_train:.4f}")
    else:
        summary.append("Best Training Spearman Correlation: N/A")
    if best_sp_val is not None:
        summary.append(f"Best Validation Spearman Correlation: {best_sp_val:.4f}")
    else:
        summary.append("Best Validation Spearman Correlation: N/A")
    summary.append(f"\nBest hyperparameters:")
    for key, value in study.best_trial.params.items():
        summary.append(f"  {key}: {value}")
    summary.append("="*60)

    # Print to console
    for line in summary:
        print(line)

    # Save the summary next to the best model.
    metrics_file = os.path.join(best_model_path, "optimization_metrics.txt")
    with open(metrics_file, 'w') as f:
        f.write('\n'.join(summary))
    print(f"\n✓ Metrics saved to: {metrics_file}")
|