|
|
|
|
|
import os

import esm
import pandas as pd
import torch
import torch.nn as nn
from transformers import AutoModelForMaskedLM, AutoModelForCausalLM, AutoTokenizer, AutoModel

from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
|
|
|
|
# Root directory under which the tokenizer vocab/splits and the binding-model checkpoint
# are looked up. It can be overridden with the PEPTIVERSE_BASE_PATH environment variable,
# so the module is not tied to one machine's filesystem. The original path stays the default.
base_path = os.environ.get("PEPTIVERSE_BASE_PATH", "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse")
|
|
|
|
|
class ImprovedBindingPredictor(nn.Module):
    """Cross-attention model over a protein embedding and a SMILES embedding.

    Each input is linearly projected to ``hidden_dim`` and layer-normalised. The two
    representations are then refined by ``n_layers`` rounds of cross-attention: the
    protein attends to the SMILES, then the SMILES attends to the just-updated protein,
    and both directions share that round's weights. Each side is mean-pooled over dim 0;
    the pooled vectors are concatenated and passed through a shared MLP trunk that feeds
    two heads:

    * a regression head with 1 output, and
    * a classification head with 3 logits.
    """

    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=768,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1):
        super().__init__()

        # Affinity cut-offs consumed by get_binding_class.
        self.tight_threshold = 7.5
        self.weak_threshold = 6.0

        # NOTE: attribute names and the order in which sub-modules are created are kept
        # fixed. They determine the state_dict keys of saved checkpoints and the order in
        # which parameters draw from the RNG at initialisation.
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        def cross_block():
            # Attention + position-wise FFN, each with a residual connection and LayerNorm.
            ffn_width = hidden_dim * 4
            return nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, ffn_width),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(ffn_width, hidden_dim),
                ),
                'norm2': nn.LayerNorm(hidden_dim),
            })

        self.cross_attention_layers = nn.ModuleList(cross_block() for _ in range(n_layers))

        # Trunk applied to the concatenated [protein_pool, smiles_pool] vector.
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Single-value affinity prediction.
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Three logits. Presumably these align with the class indices returned by
        # get_binding_class (0 tight, 1 medium, 2 weak); confirm against the training code.
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Map affinity value(s) to binding-class indices.

        0: tight binding  (affinity >= tight_threshold, 7.5)
        1: medium binding (weak_threshold <= affinity < tight_threshold)
        2: weak binding   (affinity < weak_threshold, 6.0)

        A ``torch.Tensor`` yields a ``torch.long`` tensor of the same shape. Any other
        value is compared as a scalar and yields a Python ``int``.
        """
        tight, weak = self.tight_threshold, self.weak_threshold

        if not isinstance(affinity, torch.Tensor):
            if affinity >= tight:
                return 0
            return 2 if affinity < weak else 1

        # Start every element at "medium", then overwrite the tight and weak ones.
        labels = torch.ones_like(affinity, dtype=torch.long)
        labels[affinity >= tight] = 0
        labels[affinity < weak] = 2
        return labels

    def _cross_update(self, layer, query, context):
        """Apply one residual attention + FFN update to ``query``, using ``context`` as keys/values."""
        attn_out, _ = layer['attention'](query, context, context)
        query = layer['norm1'](query + attn_out)
        return layer['norm2'](query + layer['ffn'](query))

    def forward(self, protein_emb, smiles_emb):
        """Predict affinity and class logits for a protein/SMILES embedding pair.

        The attention modules are built without ``batch_first``, and pooling is over
        dim 0. Inputs are therefore presumably ``(seq_len, dim)`` or
        ``(seq_len, batch, dim)``, with last dims ``esm_dim`` and ``smiles_dim``
        (TODO confirm against callers).

        Returns:
            ``(regression_output, classification_logits)``, whose trailing dims are 1 and 3.
        """
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(smiles_emb))

        for layer in self.cross_attention_layers:
            protein = self._cross_update(layer, protein, smiles)
            # The SMILES side attends to the protein representation updated just above.
            smiles = self._cross_update(layer, smiles, protein)

        pooled = torch.cat((protein.mean(dim=0), smiles.mean(dim=0)), dim=-1)
        features = self.shared_head(pooled)
        return self.regression_head(features), self.classification_head(features)
|
|
|
|
|
class BindingAffinity:
    """Score peptide SMILES strings against one fixed protein target.

    The target protein is embedded once at construction time, using ESM-2
    (``esm2_t33_650M_UR50D``, layer 33, mean-pooled over tokens). Each call embeds every
    peptide SMILES with the PeptideCLM RoFormer backbone, also mean-pooled over tokens,
    and returns the regression output of the trained ``ImprovedBindingPredictor``.
    Every model runs on CPU.
    """

    def __init__(self, prot_seq, model_type='PeptideCLM'):
        """Load the encoders and the binding checkpoint, then embed ``prot_seq``.

        Args:
            prot_seq: amino-acid sequence of the target protein.
            model_type: currently unused; only the PeptideCLM peptide encoder is wired up.
        """
        # Peptide (SMILES) encoder: only the RoFormer backbone is used, not the MLM head.
        self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.pep_model.eval()  # explicit; harmless if the loader already returned eval mode
        self.pep_tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/functions/tokenizer/new_vocab.txt',
                                                  f'{base_path}/functions/tokenizer/new_splits.txt')

        self.model = ImprovedBindingPredictor()
        # map_location='cpu': nothing in this class moves to a GPU, so a checkpoint saved
        # from CUDA must still load on a CPU-only host.
        # SECURITY: weights_only=False unpickles arbitrary objects. Only point this at a
        # trusted checkpoint file.
        checkpoint = torch.load(f'{base_path}/src/binding/best_model.pt',
                                map_location='cpu', weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        # The ESM docs call .eval() after loading so that dropout is disabled and results
        # are deterministic.
        self.esm_model.eval()
        self.prot_tokenizer = alphabet.get_batch_converter()

        # Cached target embedding, reused by every forward() call.
        self.prot_emb = self._embed_protein(prot_seq)

    def _embed_protein(self, prot_seq):
        """Return the ESM-2 layer-33 representation of ``prot_seq``, averaged over tokens, shape (1, dim).

        The average covers every position the batch converter produced, including any
        special tokens it adds.
        """
        _, _, tokens = self.prot_tokenizer([("target", prot_seq)])
        with torch.no_grad():
            results = self.esm_model(tokens, repr_layers=[33])
        token_reps = results["representations"][33][0]  # first (only) sequence in the batch
        return torch.mean(token_reps, dim=0, keepdim=True)

    def _embed_peptide(self, smiles):
        """Return the PeptideCLM last hidden state for one SMILES string, averaged over tokens, shape (1, dim)."""
        tokens = self.pep_tokenizer(smiles, return_tensors='pt', padding=True)
        out = self.pep_model(input_ids=tokens['input_ids'],
                             attention_mask=tokens['attention_mask'])
        # squeeze the batch dim of size 1, then average over token positions
        hidden = out.last_hidden_state.squeeze(0)
        return torch.mean(hidden, dim=0, keepdim=True)

    def forward(self, input_seqs):
        """Predict a binding-affinity score for each peptide SMILES.

        Args:
            input_seqs: iterable of SMILES strings. A single ``str`` is treated as one
                sequence rather than being iterated character by character.

        Returns:
            list of float scores, one per input sequence, in input order.
        """
        if isinstance(input_seqs, str):
            input_seqs = [input_seqs]

        scores = []
        with torch.no_grad():
            for smiles in input_seqs:
                pep_emb = self._embed_peptide(smiles)
                # Only the regression output is used; the class logits are discarded.
                score, _ = self.model(self.prot_emb, pep_emb)
                scores.append(score.item())
        return scores

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
|
|
|
|
|
def unittest(target='tfr', smiles=None):
    """Smoke-test BindingAffinity by scoring peptide SMILES against a bundled target.

    Args:
        target: key of the bundled protein sequence to score against: one of
            'amhr', 'tfr', 'gfap', 'glp1', 'glast', 'ncam'. Defaults to 'tfr', the
            target the script has always scored.
        smiles: list of peptide SMILES to score. Defaults to a single example peptide.

    Returns:
        The list of predicted scores, which is also printed.

    Raises:
        ValueError: if ``target`` is not one of the bundled sequences.
    """
    targets = {
        'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV',
        'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF',
        'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM',
        'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS',
        'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM',
        'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF',
    }

    # Validate before loading any models, so a typo fails fast.
    try:
        prot_seq = targets[target]
    except KeyError:
        raise ValueError(f"unknown target {target!r}; choose one of {sorted(targets)}") from None

    if smiles is None:
        smiles = ["CC[C@H](C)[C@H](NC(=O)[C@H](C)NC(=O)[C@@H](N)Cc1c[nH]cn1)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N1CCC[C@H]1C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1c[nH]cn1)C(=O)O"]

    binding = BindingAffinity(prot_seq)
    scores = binding(smiles)
    print(scores)
    return scores


if __name__ == '__main__':
    unittest()