import gradio as gr
import pandas as pd
import torch
from transformers import EsmTokenizer

from model import CleavageSiteModel

# Load tokenizer and model once at start-up; every request reuses them.
tokenizer = EsmTokenizer.from_pretrained("tokenizer")  # Path to tokenizer folder
model = CleavageSiteModel(num_classes=75, base_model="facebook/esm2_t30_150M_UR50D")
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
model.eval()

# Load example (sequence, cleavage_site) pairs from CSV for the examples panel.
examples_df = pd.read_csv("example_inputs.csv")
examples = examples_df[["sequence", "cleavage_site"]].to_numpy().tolist()


def predict(sequence, true_site=None):
    """Predict the signal-peptide cleavage site index for a protein sequence.

    Args:
        sequence: Protein sequence from the textbox. All whitespace (e.g.
            line breaks from pasting) is removed and letters are upper-cased
            before tokenization.
        true_site: Optional known cleavage site index, echoed back for
            comparison. ``None`` when the field is left empty.

    Returns:
        A human-readable string with the predicted class index, plus the
        true index when one was supplied.

    Raises:
        gr.Error: If the sequence is empty after cleaning.
    """
    # Strip pasted whitespace/newlines; presumably the ESM vocab holds
    # upper-case amino-acid letters only (lowercase would map to <unk>) — confirm.
    cleaned = "".join((sequence or "").split()).upper()
    if not cleaned:
        raise gr.Error("Please enter a protein sequence.")

    inputs = tokenizer(cleaned, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs["logits"]
    # assumes logits is (batch, num_classes) — TODO confirm against CleavageSiteModel
    prediction = logits.argmax(dim=1).item()

    message = f"Predicted cleavage site index: {prediction}"
    if true_site is not None:
        # Show integral floats (e.g. 23.0) as plain integers.
        if isinstance(true_site, float) and true_site.is_integer():
            true_site = int(true_site)
        message += f" (True: {true_site})"
    return message


# Build and launch the Gradio interface.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Protein Sequence", lines=2),
        # precision=0 makes Gradio pass an int (or None when blank), not a float.
        gr.Number(label="True Cleavage Site", precision=0),
    ],
    outputs=gr.Textbox(label="Model Output"),
    examples=examples,
    title="Signal Peptide Cleavage Site Predictor",
    description=(
        "Created by Nicolai Thorer Sivesind & Erlend Rønning\n\n"
        "Select an example or enter your own protein sequence "
        "and (optionally) its known cleavage site index."
    ),
)
demo.launch()