Brilleslangen commited on
Commit ·
aec9df8
1
Parent(s): b385c11
App for cleavage site prediction.
Browse files- app.py +39 -0
- example_inputs.csv +21 -0
- model.pt +3 -0
- model.py +65 -0
- requirements.txt +6 -0
- tokenizer/special_tokens_map.json +7 -0
- tokenizer/tokenizer_config.json +53 -0
- tokenizer/vocab.txt +33 -0
app.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from transformers import EsmTokenizer
|
| 5 |
+
from model import CleavageSiteModel
|
| 6 |
+
|
| 7 |
+
# Load tokenizer and model
|
| 8 |
+
tokenizer = EsmTokenizer.from_pretrained("tokenizer") # Path to tokenizer folder
|
| 9 |
+
model = CleavageSiteModel(num_classes=75, base_model="facebook/esm2_t30_150M_UR50D")
|
| 10 |
+
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
|
| 11 |
+
model.eval()
|
| 12 |
+
|
| 13 |
+
# Load example sequences and labels from CSV
|
| 14 |
+
examples_df = pd.read_csv("example_inputs.csv")
|
| 15 |
+
examples = examples_df[["sequence", "cleavage_site"]].values.tolist()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Inference function accepting both sequence and true label
|
| 19 |
+
def predict(sequence, true_site):
|
| 20 |
+
inputs = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
|
| 21 |
+
with torch.no_grad():
|
| 22 |
+
outputs = model(**inputs)
|
| 23 |
+
logits = outputs["logits"]
|
| 24 |
+
prediction = logits.argmax(dim=1).item()
|
| 25 |
+
return f"Predicted cleavage site index: {prediction} (True: {true_site})"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Launch Gradio interface
|
| 29 |
+
gr.Interface(
|
| 30 |
+
fn=predict,
|
| 31 |
+
inputs=[
|
| 32 |
+
gr.Textbox(label="Protein Sequence", lines=2),
|
| 33 |
+
gr.Number(label="True Cleavage Site")
|
| 34 |
+
],
|
| 35 |
+
outputs=gr.Textbox(label="Model Output"),
|
| 36 |
+
examples=examples,
|
| 37 |
+
title="Signal Peptide Cleavage Site Predictor",
|
| 38 |
+
description="Select an example or enter your own protein sequence and (optionally) its known cleavage site index."
|
| 39 |
+
).launch()
|
example_inputs.csv
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sequence,cleavage_site
|
| 2 |
+
MKAVITLLFLACILVVTYGDLICGTNYCKDHPCTSPIARASCRSPATYRANHSGKCACCPACVTLLRERA,18
|
| 3 |
+
MKIILTLSIFLICFLQLGQSVIDPSQNEVMSDLLFNLYGYDKSLDPCNSNSVECDDINSTSTIKTVISLN,19
|
| 4 |
+
MKHLLTLALCFSSINAVAVTVPHKAVGTGIPEGSLQFLSLRASAPIGSAISRNNWAVTCDSAQSGNECNK,23
|
| 5 |
+
MLFKSLSKLATAAAFFAGVATADDVPAIEVVGNKFFYSNNGSQFYIRGVAYQADTANETSGSTVNDPLAN,21
|
| 6 |
+
MVRPKHQPGGLCLLLLLLCQFMEDRSAQAGNCWLRQAKNGRCQVLYKTELSKEECCSTGRLSTSWTEEDV,28
|
| 7 |
+
METVLILCSLLAPVVLASAAEKEKEKDPFYYDYQTLRIGGLVFAVVLFSVGILLILSRRCKCSFNQKPRA,16
|
| 8 |
+
MKIILILSIFLICFLQLGQSVIDPSQNEVMSDLLFNLYGYDKSLDPCNNNYVECEYINTTSTIQTVKSLS,19
|
| 9 |
+
MERPVPSRLVPLPLLLLSSLSLLAARANADISMEACCTDGNQMANQHRDCSLPYTSESKECRMVQEQCCH,28
|
| 10 |
+
MKSFVLLFCLAQLWGCHSIPLDPVAGYKEPACDDPDTEQAALAAVDYINKHLPRGYKHTLNQIDSVKVWP,17
|
| 11 |
+
MLSLRVACLILSLASTVWTADTGTTSEFIEAGGDIRGPRIVERQPSQCKETDWPFCSDEDWNHKCPSGCR,18
|
| 12 |
+
MNSVLFLTLAVCSSLAYGKEFVATVRQNYKENINQLLEQQIQKELAASYIYQAYASYFQRADVSLPGIKK,17
|
| 13 |
+
MKSVQFCFLFCCWRAICCRSCELTNITITVEKEECSFCISINTTWCAGYCYTRDLVYKDPARPNIQKACT,18
|
| 14 |
+
MVRARHQPGGLCLLLLLLCQFMEDRSAQAGNCWLRQAKNGRCQVLYKTELSKEECCSTGRLSTSWTEEDV,28
|
| 15 |
+
MNSLVALVLLGQIIGSTLSSQVRGDLECDEKDAKEWTDTGVRYINEHKLHGYKYALNVIKNIVVVPWDGD,18
|
| 16 |
+
MVKFLLLALALGVSCAHYQNLEVSPSEVDGKWYSLYIAADNKEKVSEGGPLRAYIKNVECIDECQTLKIT,15
|
| 17 |
+
MWLLVSVILISRISSVGGEAMFCDFPKINHGILYDEEKYKPFSQVPTGEVFYYSCEYNFVSPSKSFWTRI,17
|
| 18 |
+
MKPIFLVLLVATSAYAAPSVTINQYSDNEIPRDIDDGKASSVISRAWDYVDDTDKSIAILNVQEILKDMA,15
|
| 19 |
+
MARNMNILTLFAVLIGSASAVYHPPSWTAWIAPKPWTAWKVHPPAWTAWKAHPPAWTAWKATPKPWTAWK,19
|
| 20 |
+
MAEWLLSASWQRRAKAMTAAAGSAGRAAVPLLLCALLAPGGAYVLDDSDGLGREFDGIGAVSGGGATSRL,41
|
| 21 |
+
MQRLCVYVLIFALALAAFSEASWKPRSQQPDAPLGTGANRDLELPWLEQQGPASHHRRQLGPQGPPHLVA,20
|
model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b255bb785120a15e0046580dd94f1f982c487031d8ceb8ae53eec4f5e33b30b7
|
| 3 |
+
size 595544586
|
model.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import EsmModel, AutoModel, PreTrainedModel, AutoConfig
|
| 5 |
+
import evaluate
|
| 6 |
+
import numpy as np
|
| 7 |
+
from sklearn.metrics import accuracy_score, classification_report
|
| 8 |
+
import wandb
|
| 9 |
+
|
| 10 |
+
accuracy_metric = evaluate.load("accuracy")
|
| 11 |
+
precision_metric = evaluate.load("precision")
|
| 12 |
+
recall_metric = evaluate.load("recall")
|
| 13 |
+
f1_metric = evaluate.load("f1")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CleavageSiteModel(nn.Module):
|
| 17 |
+
def __init__(self, base_model, num_classes=75, class_weights=None):
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.model = EsmModel.from_pretrained(base_model)
|
| 20 |
+
self.classifier = nn.Linear(self.model.config.hidden_size, num_classes)
|
| 21 |
+
|
| 22 |
+
if class_weights is not None:
|
| 23 |
+
# Create full-length weights tensor with zeros
|
| 24 |
+
weight_tensor = torch.zeros(num_classes)
|
| 25 |
+
for class_idx, weight in class_weights.items():
|
| 26 |
+
weight_tensor[class_idx] = weight
|
| 27 |
+
self.loss_fn = nn.CrossEntropyLoss(weight=weight_tensor)
|
| 28 |
+
else:
|
| 29 |
+
self.loss_fn = nn.CrossEntropyLoss()
|
| 30 |
+
|
| 31 |
+
def forward(self, input_ids, attention_mask, labels=None):
|
| 32 |
+
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
| 33 |
+
cls_output = outputs.last_hidden_state[:, 0]
|
| 34 |
+
logits = self.classifier(cls_output)
|
| 35 |
+
|
| 36 |
+
if labels is not None:
|
| 37 |
+
loss = self.loss_fn(logits, labels)
|
| 38 |
+
return {"loss": loss, "logits": logits}
|
| 39 |
+
else:
|
| 40 |
+
return {"logits": logits}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def compute_metrics(eval_pred):
|
| 44 |
+
# Computes classification metrics including overall accuracy and per-class accuracy.
|
| 45 |
+
|
| 46 |
+
logits, labels = eval_pred # Extract model outputs and labels
|
| 47 |
+
predictions = np.argmax(logits, axis=1) # Get predicted class
|
| 48 |
+
|
| 49 |
+
# Compute overall accuracy
|
| 50 |
+
accuracy = accuracy_score(labels, predictions)
|
| 51 |
+
|
| 52 |
+
report = classification_report(labels, predictions, digits=4)
|
| 53 |
+
wandb.log({"classification_report": wandb.Html(report.replace('\n', '<br>'))})
|
| 54 |
+
|
| 55 |
+
# Compute per-class accuracy
|
| 56 |
+
unique_classes = np.unique(labels)
|
| 57 |
+
per_class_acc = {}
|
| 58 |
+
for cls in unique_classes:
|
| 59 |
+
class_mask = labels == cls # Select samples belonging to this class
|
| 60 |
+
per_class_acc[f"accuracy_class_{cls}"] = (predictions[class_mask] == labels[class_mask]).mean()
|
| 61 |
+
|
| 62 |
+
# Log metrics
|
| 63 |
+
wandb.log({"overall_accuracy": accuracy, **per_class_acc})
|
| 64 |
+
|
| 65 |
+
return {"accuracy": accuracy, **per_class_acc}
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers
|
| 3 |
+
gradio
|
| 4 |
+
evaluate
|
| 5 |
+
scikit-learn
|
| 6 |
+
wandb
|
tokenizer/special_tokens_map.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": "<cls>",
|
| 3 |
+
"eos_token": "<eos>",
|
| 4 |
+
"mask_token": "<mask>",
|
| 5 |
+
"pad_token": "<pad>",
|
| 6 |
+
"unk_token": "<unk>"
|
| 7 |
+
}
|
tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "<cls>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "<pad>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "<eos>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "<unk>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"32": {
|
| 36 |
+
"content": "<mask>",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": false,
|
| 45 |
+
"cls_token": "<cls>",
|
| 46 |
+
"eos_token": "<eos>",
|
| 47 |
+
"extra_special_tokens": {},
|
| 48 |
+
"mask_token": "<mask>",
|
| 49 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 50 |
+
"pad_token": "<pad>",
|
| 51 |
+
"tokenizer_class": "EsmTokenizer",
|
| 52 |
+
"unk_token": "<unk>"
|
| 53 |
+
}
|
tokenizer/vocab.txt
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<cls>
|
| 2 |
+
<pad>
|
| 3 |
+
<eos>
|
| 4 |
+
<unk>
|
| 5 |
+
L
|
| 6 |
+
A
|
| 7 |
+
G
|
| 8 |
+
V
|
| 9 |
+
S
|
| 10 |
+
E
|
| 11 |
+
R
|
| 12 |
+
T
|
| 13 |
+
I
|
| 14 |
+
D
|
| 15 |
+
P
|
| 16 |
+
K
|
| 17 |
+
Q
|
| 18 |
+
N
|
| 19 |
+
F
|
| 20 |
+
Y
|
| 21 |
+
M
|
| 22 |
+
H
|
| 23 |
+
W
|
| 24 |
+
C
|
| 25 |
+
X
|
| 26 |
+
B
|
| 27 |
+
U
|
| 28 |
+
Z
|
| 29 |
+
O
|
| 30 |
+
.
|
| 31 |
+
-
|
| 32 |
+
<null_1>
|
| 33 |
+
<mask>
|