|
|
import os |
|
|
import sys |
|
|
import warnings |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from transformers import DataCollatorWithPadding |
|
|
from transformers import EsmTokenizer |
|
|
from datasets import ( |
|
|
load_dataset, |
|
|
Dataset, |
|
|
) |
|
|
|
|
|
from modeling_esm import EsmForSequenceClassificationCustomWidehead |
|
|
|
|
|
|
|
|
print("initializing checkpoint --might take a few min if this is the first time--")

tokenizer = EsmTokenizer.from_pretrained("finalCheckpoint_25_05_11/")

# Use the GPU when one is available; otherwise fall back to the CPU instead of
# crashing inside an unconditional .cuda() call.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EsmForSequenceClassificationCustomWidehead.from_pretrained(
    "finalCheckpoint_25_05_11/", num_labels=54
).to(_device)
# This script only runs inference; make the evaluation mode explicit so
# dropout-style layers are disabled.
model.eval()

print("finished loading checkpoint")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# PTM label name -> column index in the per-peptide score vector written to the
# output CSV (write_output uses pos2lab[i] as the header of score column i).
lab2map = {
    "S_Phosphorylation": 0,
    "T_Phosphorylation": 1,
    "Y_Phosphorylation": 3,
    "A_Acetylation": 13,
    "M_Acetylation": 14,
    "K_Acetylation": 4,
    "K_Ubiquitination": 2,
    "S_O-linked-Glycosylation": 6,
    "T_O-linked-Glycosylation": 7,
    "N_N-linked-Glycosylation": 5,
    "K_Methylation": 9,
    "R_Methylation": 8,
    "K_Malonylation": 11,
    "K_Sumoylation": 10,
    "C_Glutathionylation": 15,
    "P_Hydroxylation": 17,
    "K_Hydroxylation": 18,
    "C_S-palmitoylation": 16,
    "M_Sulfoxidation": 12,
}

# Labels of interest: every mapped label. Deriving this set from lab2map keeps
# the two in sync; M_Sulfoxidation used to be mapped but missing from this set.
labsoi = set(lab2map)

# Inverse lookup: column index -> label name.
# Note: getlab() also emits a score column 19, which has no name in this table.
pos2lab = {pos: lab for lab, pos in lab2map.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_function(examples):
    """Tokenize every peptide in ``examples["pep"]``.

    Before tokenizing, each '.' is rewritten to the "<mask>" token and each '-'
    to the "<pad>" token. Each peptide is tokenized on its own with the
    module-level ``tokenizer``; no padding or truncation is requested.

    Args:
        examples: mapping with a "pep" entry holding an iterable of peptide
            strings (presumably a batched ``datasets`` map call — confirm).

    Returns:
        dict with "input_ids" and "attention_mask", each a list containing one
        per-peptide list in input order.
    """
    encodings = [
        tokenizer(peptide.replace(".", "<mask>").replace("-", "<pad>"))
        for peptide in examples["pep"]
    ]
    return {
        "input_ids": [enc["input_ids"] for enc in encodings],
        "attention_mask": [enc["attention_mask"] for enc in encodings],
    }
|
|
|
|
|
|
|
|
def getlab(elab, res):
    """Collapse one row of raw model scores into the 20 output PTM columns.

    Args:
        elab: indexable sequence of raw per-label scores, at least 54 long
            (slice 53:54 is always read).
        res: the peptide's centre residue (one-letter string).

    Returns:
        ``np.ndarray`` of 20 floats. Each column holds the maximum over a fixed
        slice of ``elab``. Where two columns share one slice (S/T, S/T, R/K,
        A/M, P/K), only the column matching ``res`` receives that maximum; the
        other one stays 0, and both stay 0 when ``res`` matches neither.
    """
    def top(lo, hi):
        return max(elab[lo:hi])

    scores = np.zeros((20))

    # Columns filled for every residue: (column, slice start, slice stop).
    unconditional = (
        (2, 5, 25),
        (3, 25, 26),
        (4, 26, 36),
        (5, 36, 37),
        (10, 46, 47),
        (11, 47, 48),
        (12, 48, 49),
        (15, 50, 51),
        (16, 51, 52),
        (19, 53, 54),
    )
    for col, lo, hi in unconditional:
        scores[col] = top(lo, hi)

    # Slices shared by two columns; the centre residue decides which column
    # gets the score: (slice start, slice stop, ((residue, column), ...)).
    residue_gated = (
        (0, 5, (("S", 0), ("T", 1))),
        (37, 42, (("S", 6), ("T", 7))),
        (42, 46, (("R", 8), ("K", 9))),
        (49, 50, (("A", 13), ("M", 14))),
        (52, 53, (("P", 17), ("K", 18))),
    )
    for lo, hi, choices in residue_gated:
        for residue, col in choices:
            if res == residue:
                scores[col] = top(lo, hi)
                break

    return scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(input_batches):
    """Score batches of peptides with the module-level classifier.

    Args:
        input_batches: list of batches; each batch is a list of peptide
            strings (main() prepares them with '-' rewritten to '<pad>').

    Returns:
        Flat list with one entry per peptide, in batch order. Each entry is the
        sigmoid of the model's logits for that peptide, as a list of floats
        (54 values for the checkpoint loaded with num_labels=54).
    """
    sig = nn.Sigmoid()
    # Send inputs to wherever the model lives (GPU or CPU) instead of
    # hard-coding .cuda().
    device = next(model.parameters()).device
    outputpreds = []
    total = len(input_batches)
    # Inference only: no_grad keeps autograd from recording a graph for every
    # batch, which would otherwise waste (GPU) memory.
    with torch.no_grad():
        for i, batch in enumerate(input_batches):
            print(f"{i} / {total} batches done", end="\r")
            if not batch:
                # Nothing to score in an empty batch.
                continue
            # Tokenize once per batch. return_tensors="pt" always yields a
            # (batch, seq_len) tensor, even for a single peptide; the old
            # torch.tensor([...]).squeeze() collapsed a 1-peptide batch to 1-D
            # and crashed the model.
            enc = tokenizer(batch, return_tensors="pt")
            logits = model(
                enc["input_ids"].to(device),
                enc["attention_mask"].to(device),
            )["logits"]
            outputpreds.extend(sig(logits).tolist())
    print(f"{total} / {total} batches done")
    return outputpreds
|
|
|
|
|
def write_output(pred, listofpeps, file_output):
    """Write per-peptide PTM scores to a CSV file.

    Args:
        pred: raw per-peptide score rows (as returned by predict), in the same
            order as ``listofpeps``.
        listofpeps: peptide strings; the character at index 10 is passed to
            getlab as the centre residue, so each peptide must have at least
            11 characters.
        file_output: path of the CSV file to create or overwrite.

    The first line is a header, ``pep`` followed by one name per score column
    taken from pos2lab. Each following line is a peptide followed by the
    scores getlab returns for it.
    """
    rows = [(pep, getlab(scores, pep[10])) for scores, pep in zip(pred, listofpeps)]

    # Size the header from getlab's actual output so the header and the data
    # rows always have the same number of columns (a header sized by
    # len(labsoi) was shorter than getlab's 20 columns and misaligned the CSV).
    # Columns with no pos2lab entry get a placeholder name.
    # TODO(review): column 19 (raw score index 53) has no label name — confirm.
    n_cols = len(rows[0][1]) if rows else len(pos2lab)
    header = ["pep"] + [pos2lab.get(i, f"label_{i}") for i in range(n_cols)]

    with open(file_output, "w") as hf:
        hf.write(",".join(header) + "\n")
        for pep, scores in rows:
            hf.write(",".join([f"{pep}"] + [f"{s}" for s in scores]) + "\n")
|
|
|
|
|
|
|
|
# Command-line help text printed by main().
DOC_HELP = '''
Usage: python3 claspp_forward.py [OPTION]... --input INPUT [FASTA_FILE or TXT_FILE]...
Predict PTM events on peptides or full sequences.

Example 1: python3 claspp_forward.py -B 100 -S 0 -i random.txt
Example 2: python3 claspp_forward.py -B 50 -S 1 -i random.fasta

FASTA_FILE contains protein sequences in proper FASTA or a2m format.
TXT_FILE contains protein peptides 21 residues in length, with the center
residue (position 11) being the PTM modification site.


Options:
  -B, --batch_size    (int) how many peptides are predicted at a time on
                      the GPU; reduce it if you run out of GPU memory
                      (default: 50)

  -S, --scrape_fasta  (int) should be 1 or 0
                      1 = read a FASTA file and scrape every possible
                          21-residue peptide from each sequence
                      0 = read a txt file that already holds the 21-mers,
                          one peptide per line (separated by '\\n');
                          this can be faster than the FASTA option

  -h, --help          you're reading it right now

  -i, --input         location of the input FASTA or txt file

  -o, --output        location of the output CSV
                      (default: output_predictions.csv)


Report bugs to:
'''

# Banner printed before the help text when no input file was given.
WARNING_MESSAGE = """
#################################
PLEASE READ HELP MESSAGE TO ENSURE
YOU KNOW HOW TO FORMAT/USE THE
MODEL
#################################
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Every spelling of the help flag accepted on the command line.
_HELP_FLAGS = ("-h", "--h", "-help", "--help")


def _parse_args(argv):
    """Return (batch_size, scrape, input_file, file_output, wants_help) from argv.

    argv[0] (the script name) is skipped. Value flags take the argument that
    follows them; a value flag in last position has no value and is ignored.
    input_file is None when no input was given.
    """
    batch_size = 50
    scrape = 0
    file_output = "output_predictions.csv"
    input_file = None
    wants_help = False
    args = argv[1:]
    for i, arg in enumerate(args):
        if arg in _HELP_FLAGS:
            wants_help = True
            continue
        if i + 1 >= len(args):
            continue
        value = args[i + 1]
        if arg in ("--scrape_fasta", "-S"):
            scrape = int(value)
        elif arg in ("--batch_size", "-B"):
            batch_size = int(value)
        elif arg in ("--input", "-i"):
            input_file = value
        elif arg in ("--output", "-o"):
            file_output = value
    return batch_size, scrape, input_file, file_output, wants_help


def _read_peptide_file(path):
    """Read one peptide per line, stripping line endings and skipping blank lines."""
    with open(path, "r") as rf:
        # strip() instead of line[:-1]: a final line without a newline keeps its
        # last residue, and CRLF files don't leave a stray '\r'.
        return [line.strip() for line in rf if line.strip()]


def _scrape_fasta(path):
    """Return the unique 21-residue windows centred on every residue of every record.

    Each sequence is padded with 10 '-' on both sides so terminal residues get a
    full window. Header lines are only used as record separators.
    """
    seqs = []
    chunks = []
    with open(path, "r") as rf:
        for line in rf:
            line = line.strip()
            if line.startswith(">"):
                seqs.append("".join(chunks))
                chunks = []
            else:
                chunks.append(line)
    seqs.append("".join(chunks))

    pad = "-" * 10
    peps = []
    for seq in seqs:
        padded = pad + seq + pad
        peps.extend(padded[i:i + 21] for i in range(len(seq)))
    # dict.fromkeys de-duplicates while keeping first-seen order, so the CSV row
    # order is reproducible across runs (a set's string order is not).
    return list(dict.fromkeys(peps))


def _make_batches(peps, batch_size):
    """Split peps into consecutive batches of at most batch_size, with '-' rewritten to '<pad>'."""
    return [
        [pep.replace("-", "<pad>") for pep in peps[start:start + batch_size]]
        for start in range(0, len(peps), batch_size)
    ]


def main():
    """Command-line entry point: read peptides, predict PTM scores, write the CSV."""
    batch_size, scrape, input_file, file_output, wants_help = _parse_args(sys.argv)

    if wants_help:
        print(DOC_HELP)
        return

    if input_file is None:
        print(WARNING_MESSAGE)
        print(DOC_HELP)
        return

    if batch_size < 1:
        print(f"--batch_size must be a positive integer, got {batch_size}")
        return

    if scrape == 0:
        listofpeps = _read_peptide_file(input_file)
    else:
        listofpeps = _scrape_fasta(input_file)

    # listofpeps holds no blank entries, so predictions line up 1:1 with it.
    input_batches = _make_batches(listofpeps, batch_size)
    pred = predict(input_batches=input_batches)
    write_output(pred, listofpeps, file_output)


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|