# NeuroPred-PLM / app.py
# wangleiofficial
# add new annotation
# 53a1e75
import torch
from NeuroPredPLM.predict import predict, batch_predict
import gradio as gr
from io import StringIO
from Bio import SeqIO
def classifier(peptide_seq, model_path="./model.pth"):
    """Run the neuropeptide predictor on FASTA text typed into the UI.

    Args:
        peptide_seq: FASTA-formatted text, e.g. ">peptide-1\\nIGLRLPNMLKF".
            Every record that Biopython parses from it is sent to ``predict``.
        model_path: Path of the model weights handed to ``predict``
            (defaults to the original hard-coded ``./model.pth``).

    Returns:
        The value returned by ``predict`` unchanged; the single-sequence tab
        renders it in a ``gr.Label`` component.

    Raises:
        ValueError: if no FASTA record could be parsed from ``peptide_seq``
            (e.g. empty text or a sequence without a ``>`` header line), so
            the user gets a clear message instead of an empty batch being
            forwarded to the model.
    """
    handle = StringIO(peptide_seq)
    data = [(record.id, str(record.seq)) for record in SeqIO.parse(handle, "fasta")]
    if not data:
        raise ValueError(
            "No FASTA record found in the input; expected text such as "
            "'>peptide-1\\nIGLRLPNMLKF'."
        )
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return predict(data, model_path, device)
# Presumed output format of predict(): {peptide_id: [type: int (1 -> neuropeptide, 0 -> non-neuropeptide), attention score: np.ndarray]}
def batch_classifier(file, cutoff, model_path="./model.pth"):
    """Run the neuropeptide predictor on every record of an uploaded FASTA file.

    Args:
        file: Value of the ``gr.File`` component. The original code reads the
            uploaded file's path from ``file.name``; a plain path string is
            also accepted (some Gradio versions pass the path directly).
            ``None`` when the user pressed Submit without uploading anything.
        cutoff: Threshold from the batch slider, forwarded to ``batch_predict``.
        model_path: Path of the model weights handed to ``batch_predict``
            (defaults to the original hard-coded ``./model.pth``).

    Returns:
        The value returned by ``batch_predict`` unchanged; the batch tab
        renders it in a ``gr.DataFrame``.

    Raises:
        ValueError: if no file was uploaded, or the file contains no FASTA
            records. Previously a missing upload crashed with an opaque
            ``AttributeError`` on ``None.name``.
    """
    if file is None:
        raise ValueError("Please upload a FASTA file before submitting.")
    path = getattr(file, "name", file)
    data = [(record.id, str(record.seq)) for record in SeqIO.parse(path, "fasta")]
    if not data:
        raise ValueError("No FASTA record found in the uploaded file.")
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return batch_predict(data, cutoff, model_path, device)
with gr.Blocks() as demo:
    # Page header and project description.
    gr.Markdown(" ## NeuroPred-PLM")
    gr.Markdown("In this work, we developed an interpretable and robust neuropeptide prediction model, named NeuroPred-PLM. First, we employed a language model (ESM) of proteins to obtain semantic representations of neuropeptides, which could reduce the complexity of feature engineering. Next, we adopted a multi-scale convolutional neural network to enhance the local feature representation of neuropeptide embeddings. To make the model interpretable, we proposed a global multi-head attention network that could be used to capture the position-wise contribution to neuropeptide prediction via the attention scores. In addition, NeuroPred-PLM was developed based on our newly constructed NeuroPep 2.0 database. Benchmarks based on the independent test set show that NeuroPred-PLM achieves superior predictive performance compared to other state-of-the-art predictors.")

    # Tab 1: classify FASTA text typed into a textbox.
    with gr.Tab("Single Sequence Model"):
        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="Input single peptide sequence in the Fasta format",
                    lines=4,
                    value=">peptide-1\nIGLRLPNMLKF",
                )
                gr.Markdown("#### The input peptide sequence length should be between 5-100")
                # NOTE(review): this slider is NOT passed to `classifier` (see
                # the text_button.click call at the end of this block), so
                # moving it does not change the single-sequence prediction.
                # Wire it up once `classifier`/`predict` accept a threshold,
                # or remove it.
                single_cutoff = gr.Slider(0, 1, step=0.1, value=0.5, interactive=True, label="Threshold")
                text_button = gr.Button("Submit")
            with gr.Column(scale=2):
                gr.Markdown("Note: the output scores indicate the probability scores of the input sequence to be predicted as a neuropeptide or a non-neuropeptide.")
                # gr.Label replaces the deprecated gr.outputs.Label namespace.
                text_output = gr.Label(num_top_classes=2, label="Output")

    # Tab 2: classify every record of an uploaded FASTA file.
    with gr.Tab("Batch Model"):
        with gr.Row():
            with gr.Column(scale=2):
                input_file_fasta = gr.File()
            with gr.Column(scale=2):
                batch_cutoff = gr.Slider(0, 1, step=0.1, value=0.5, interactive=True, label="Threshold")
                gr.Markdown("### Note")
                gr.Markdown("- Limit the number of input sequences to less than 30")
                gr.Markdown("- The file should be in the Fasta format")
                gr.Markdown("- The input peptide sequence length should be between 5-100")
                batch_button = gr.Button("Submit")
        with gr.Column():
            gr.Markdown("Note: the output scores indicate the probability scores of the input sequence to be predicted as a neuropeptide or a non-neuropeptide.")
            frame_output = gr.DataFrame(
                headers=["Sequence Id", "Sequence", "Probability of neuropeptides", "Neuropeptide"],
                datatype=["str", "str", "str", "str"],
            )

    # Reference information shown below both tabs.
    with gr.Accordion("Citation"):
        gr.Markdown("- Wang, L., Huang, C., Wang, M., Xue, Z., & Wang, Y. (2022). NeuroPred-PLM: an interpretable and robust model for neuropeptide prediction by protein language model. In preparation.")
        gr.Markdown("- GitHub: https://github.com/ISYSLAB-HUST/NeuroPred-PLM")
    with gr.Accordion("License"):
        gr.Markdown("- Released under the [MIT license](https://github.com/ISYSLAB-HUST/NeuroPred-PLM/blob/main/LICENSE). ")
    with gr.Accordion("Contact"):
        gr.Markdown("- If you have any questions, comments, or would like to report a bug, please file a Github issue or contact me at wanglei94@hust.edu.cn.")

    # Connect each Submit button to its prediction function.
    text_button.click(classifier, inputs=text_input, outputs=text_output)
    batch_button.click(batch_classifier, inputs=[input_file_fasta, batch_cutoff], outputs=frame_output)

# NOTE(review): under the Gradio 3.x signature the first positional argument
# of Blocks.queue is concurrency_count (up to 4 requests processed at once);
# confirm against the installed Gradio version.
demo.queue(4)
demo.launch()