| | import gradio as gr |
| | import torch |
| | import pandas as pd |
| | from transformers import EsmTokenizer |
| | from model import CleavageSiteModel |
| |
|
| | |
# Tokenizer vocabulary/config is read from the local "tokenizer" directory.
tokenizer = EsmTokenizer.from_pretrained("tokenizer")

# 75-way classifier over the ESM-2 (t30, 150M) base model; the fine-tuned
# weights are then restored from the bundled checkpoint, mapped onto the CPU.
model = CleavageSiteModel(
    base_model="facebook/esm2_t30_150M_UR50D",
    num_classes=75,
)
# NOTE(review): torch.load unpickles the file, so only ever point this at a
# trusted checkpoint (on torch>=1.13, weights_only=True is a safer option).
model.load_state_dict(
    torch.load("model.pt", map_location="cpu")
)
# Inference only: put the module into evaluation mode (e.g. disables dropout).
model.eval()
| |
|
| | |
# Example rows shown under the interface. Each row is [sequence, cleavage_site],
# in the same order as the interface inputs (sequence textbox, site number).
examples_df = pd.read_csv("example_inputs.csv")
examples = (
    examples_df.loc[:, ["sequence", "cleavage_site"]]
    .to_numpy()
    .tolist()
)
| |
|
| |
|
| | |
def clean_sequence(sequence):
    """Return *sequence* with all whitespace removed.

    Sequences pasted into the multi-line textbox may contain spaces or
    newlines, which are not residue symbols. ``None`` is treated as an
    empty sequence.
    """
    if sequence is None:
        return ""
    return "".join(sequence.split())


def format_prediction(prediction, true_site=None):
    """Build the text shown in the output textbox.

    Args:
        prediction: Predicted cleavage-site index (argmax class of the model).
        true_site: Optional known cleavage-site index from the number input.
            ``None`` when the field is left empty; may arrive as a float.

    Returns:
        A summary string. The "(True: ...)" part is omitted when no known
        site was supplied.
    """
    text = f"Predicted cleavage site index: {prediction}"
    if true_site is None:
        return text
    # Whole-number floats (e.g. 23.0) are shown as integers ("23").
    if isinstance(true_site, float) and true_site.is_integer():
        true_site = int(true_site)
    return f"{text} (True: {true_site})"


def predict(sequence, true_site):
    """Predict the signal-peptide cleavage site for one protein sequence.

    Args:
        sequence: Protein sequence text entered by the user. Whitespace is
            stripped before tokenization.
        true_site: Optional known cleavage-site index, echoed back for
            comparison.

    Returns:
        A human-readable string with the predicted index (and the known
        index, if given), or a prompt when the sequence is empty.
    """
    sequence = clean_sequence(sequence)
    if not sequence:
        # Nothing to classify; avoid running the model on special tokens only.
        return "Please enter a protein sequence."

    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # assumes logits are (batch=1, num_classes); TODO confirm against the model
    prediction = outputs["logits"].argmax(dim=1).item()
    return format_prediction(prediction, true_site)
| |
|
| |
|
| | |
# Gradio UI: sequence text plus an optional known cleavage site in, the
# prediction summary text out. Bound to `demo` so the interface object can be
# reused (e.g. by Gradio's reload mode); it is launched immediately as before.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Protein Sequence", lines=2),
        # precision=0 rounds the value and passes it as an int (or None when
        # left empty), so sites are shown as "23" rather than "23.0".
        gr.Number(label="True Cleavage Site", precision=0),
    ],
    outputs=gr.Textbox(label="Model Output"),
    examples=examples,
    title="Signal Peptide Cleavage Site Predictor",
    # The description is rendered as Markdown, so a blank line ("\n\n") is
    # needed for a visible paragraph break.
    description=(
        "Created by Nicolai Thorer Sivesind & Erlend Rønning\n\n"
        "Select an example or enter your own protein sequence and "
        "(optionally) its known cleavage site index."
    ),
)
demo.launch()
| |
|