Brilleslangen committed on
Commit
aec9df8
·
1 Parent(s): b385c11

App for cleavage site prediction.

Browse files
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
import pandas as pd
from transformers import EsmTokenizer
from model import CleavageSiteModel

# Load the ESM tokenizer from the local "tokenizer" folder and restore the
# trained classifier weights on CPU for inference.
tokenizer = EsmTokenizer.from_pretrained("tokenizer")
model = CleavageSiteModel(num_classes=75, base_model="facebook/esm2_t30_150M_UR50D")
# weights_only=True: model.pt holds a plain state_dict, so refuse arbitrary
# pickled objects — torch.load is otherwise a pickle-deserialization risk.
model.load_state_dict(torch.load("model.pt", map_location="cpu", weights_only=True))
model.eval()

# Example (sequence, true cleavage site) pairs shown in the Gradio UI.
examples_df = pd.read_csv("example_inputs.csv")
examples = examples_df[["sequence", "cleavage_site"]].values.tolist()
16
+
17
+
18
# Inference function accepting both the sequence and an optional true label
def predict(sequence, true_site):
    """Predict the cleavage-site index for *sequence*.

    Args:
        sequence: Protein sequence string (amino-acid letters).
        true_site: Known cleavage-site index, or None when the user leaves the
            optional field blank (gr.Number otherwise delivers a float).

    Returns:
        Human-readable string with the predicted index and, when supplied,
        the known index for comparison.
    """
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = outputs["logits"].argmax(dim=1).item()
    if true_site is None:
        # The UI describes the true site as optional; don't print "(True: None)".
        return f"Predicted cleavage site index: {prediction}"
    # gr.Number yields floats; render "18" rather than "18.0".
    return f"Predicted cleavage site index: {prediction} (True: {int(true_site)})"
26
+
27
+
28
# Assemble the Gradio demo and start serving it.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Protein Sequence", lines=2),
        gr.Number(label="True Cleavage Site"),
    ],
    outputs=gr.Textbox(label="Model Output"),
    examples=examples,
    title="Signal Peptide Cleavage Site Predictor",
    description="Select an example or enter your own protein sequence and (optionally) its known cleavage site index.",
)
demo.launch()
example_inputs.csv ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sequence,cleavage_site
2
+ MKAVITLLFLACILVVTYGDLICGTNYCKDHPCTSPIARASCRSPATYRANHSGKCACCPACVTLLRERA,18
3
+ MKIILTLSIFLICFLQLGQSVIDPSQNEVMSDLLFNLYGYDKSLDPCNSNSVECDDINSTSTIKTVISLN,19
4
+ MKHLLTLALCFSSINAVAVTVPHKAVGTGIPEGSLQFLSLRASAPIGSAISRNNWAVTCDSAQSGNECNK,23
5
+ MLFKSLSKLATAAAFFAGVATADDVPAIEVVGNKFFYSNNGSQFYIRGVAYQADTANETSGSTVNDPLAN,21
6
+ MVRPKHQPGGLCLLLLLLCQFMEDRSAQAGNCWLRQAKNGRCQVLYKTELSKEECCSTGRLSTSWTEEDV,28
7
+ METVLILCSLLAPVVLASAAEKEKEKDPFYYDYQTLRIGGLVFAVVLFSVGILLILSRRCKCSFNQKPRA,16
8
+ MKIILILSIFLICFLQLGQSVIDPSQNEVMSDLLFNLYGYDKSLDPCNNNYVECEYINTTSTIQTVKSLS,19
9
+ MERPVPSRLVPLPLLLLSSLSLLAARANADISMEACCTDGNQMANQHRDCSLPYTSESKECRMVQEQCCH,28
10
+ MKSFVLLFCLAQLWGCHSIPLDPVAGYKEPACDDPDTEQAALAAVDYINKHLPRGYKHTLNQIDSVKVWP,17
11
+ MLSLRVACLILSLASTVWTADTGTTSEFIEAGGDIRGPRIVERQPSQCKETDWPFCSDEDWNHKCPSGCR,18
12
+ MNSVLFLTLAVCSSLAYGKEFVATVRQNYKENINQLLEQQIQKELAASYIYQAYASYFQRADVSLPGIKK,17
13
+ MKSVQFCFLFCCWRAICCRSCELTNITITVEKEECSFCISINTTWCAGYCYTRDLVYKDPARPNIQKACT,18
14
+ MVRARHQPGGLCLLLLLLCQFMEDRSAQAGNCWLRQAKNGRCQVLYKTELSKEECCSTGRLSTSWTEEDV,28
15
+ MNSLVALVLLGQIIGSTLSSQVRGDLECDEKDAKEWTDTGVRYINEHKLHGYKYALNVIKNIVVVPWDGD,18
16
+ MVKFLLLALALGVSCAHYQNLEVSPSEVDGKWYSLYIAADNKEKVSEGGPLRAYIKNVECIDECQTLKIT,15
17
+ MWLLVSVILISRISSVGGEAMFCDFPKINHGILYDEEKYKPFSQVPTGEVFYYSCEYNFVSPSKSFWTRI,17
18
+ MKPIFLVLLVATSAYAAPSVTINQYSDNEIPRDIDDGKASSVISRAWDYVDDTDKSIAILNVQEILKDMA,15
19
+ MARNMNILTLFAVLIGSASAVYHPPSWTAWIAPKPWTAWKVHPPAWTAWKAHPPAWTAWKATPKPWTAWK,19
20
+ MAEWLLSASWQRRAKAMTAAAGSAGRAAVPLLLCALLAPGGAYVLDDSDGLGREFDGIGAVSGGGATSRL,41
21
+ MQRLCVYVLIFALALAAFSEASWKPRSQQPDAPLGTGANRDLELPWLEQQGPASHHRRQLGPQGPPHLVA,20
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b255bb785120a15e0046580dd94f1f982c487031d8ceb8ae53eec4f5e33b30b7
3
+ size 595544586
model.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ from transformers import EsmModel, AutoModel, PreTrainedModel, AutoConfig
5
+ import evaluate
6
+ import numpy as np
7
+ from sklearn.metrics import accuracy_score, classification_report
8
+ import wandb
9
+
10
# Hugging Face `evaluate` metric handles, loaded once at import time.
# NOTE(review): none of these four handles is referenced anywhere in this
# file — compute_metrics below uses sklearn directly. Confirm they are used
# elsewhere (or intended for future use) before removing.
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")
14
+
15
+
16
class CleavageSiteModel(nn.Module):
    """ESM-2 backbone with a linear head for cleavage-site classification.

    The CLS-token embedding from the base model is projected to
    ``num_classes`` position logits; one class per candidate site index.
    """

    def __init__(self, base_model, num_classes=75, class_weights=None):
        super().__init__()
        self.model = EsmModel.from_pretrained(base_model)
        self.classifier = nn.Linear(self.model.config.hidden_size, num_classes)

        if class_weights is None:
            self.loss_fn = nn.CrossEntropyLoss()
        else:
            # Expand the {class_index: weight} mapping into a dense vector;
            # classes absent from the mapping get weight 0.
            weight_tensor = torch.zeros(num_classes)
            for class_idx, weight in class_weights.items():
                weight_tensor[class_idx] = weight
            self.loss_fn = nn.CrossEntropyLoss(weight=weight_tensor)

    def forward(self, input_ids, attention_mask, labels=None):
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Pool via the CLS token (position 0) and classify.
        logits = self.classifier(encoded.last_hidden_state[:, 0])

        if labels is None:
            return {"logits": logits}
        return {"loss": self.loss_fn(logits, labels), "logits": logits}
41
+
42
+
43
def compute_metrics(eval_pred):
    """Compute overall accuracy plus per-class accuracy for an eval step.

    Args:
        eval_pred: Pair of (logits, labels) as numpy arrays.

    Returns:
        Dict with ``accuracy`` and one ``accuracy_class_<c>`` entry per
        class present in *labels*. Also logs the sklearn classification
        report and all metrics to wandb as a side effect.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)

    overall = accuracy_score(labels, predictions)

    # Render the sklearn text report as HTML for the wandb dashboard.
    report = classification_report(labels, predictions, digits=4)
    wandb.log({"classification_report": wandb.Html(report.replace('\n', '<br>'))})

    # Accuracy restricted to the samples of each class actually present.
    per_class = {}
    for cls in np.unique(labels):
        mask = labels == cls
        per_class[f"accuracy_class_{cls}"] = (predictions[mask] == labels[mask]).mean()

    wandb.log({"overall_accuracy": overall, **per_class})

    return {"accuracy": overall, **per_class}
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ evaluate
5
+ scikit-learn
6
+ wandb
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "<cls>",
3
+ "eos_token": "<eos>",
4
+ "mask_token": "<mask>",
5
+ "pad_token": "<pad>",
6
+ "unk_token": "<unk>"
7
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<cls>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<eos>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "32": {
36
+ "content": "<mask>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": false,
45
+ "cls_token": "<cls>",
46
+ "eos_token": "<eos>",
47
+ "extra_special_tokens": {},
48
+ "mask_token": "<mask>",
49
+ "model_max_length": 1000000000000000019884624838656,
50
+ "pad_token": "<pad>",
51
+ "tokenizer_class": "EsmTokenizer",
52
+ "unk_token": "<unk>"
53
+ }
tokenizer/vocab.txt ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <cls>
2
+ <pad>
3
+ <eos>
4
+ <unk>
5
+ L
6
+ A
7
+ G
8
+ V
9
+ S
10
+ E
11
+ R
12
+ T
13
+ I
14
+ D
15
+ P
16
+ K
17
+ Q
18
+ N
19
+ F
20
+ Y
21
+ M
22
+ H
23
+ W
24
+ C
25
+ X
26
+ B
27
+ U
28
+ Z
29
+ O
30
+ .
31
+ -
32
+ <null_1>
33
+ <mask>