Gokulerusappan AmelieSchreiber committed on
Commit
490b2a6
0 Parent(s):

Duplicate from AmelieSchreiber/cafa_5_protein_function_prediction


Co-authored-by: Amelie Schreiber <AmelieSchreiber@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,145 @@
+ ---
+ license: mit
+ language:
+ - en
+ library_name: transformers
+ tags:
+ - esm
+ - esm2
+ - biology
+ - protein
+ - protein language model
+ - cafa 5
+ - protein function prediction
+ datasets:
+ - AmelieSchreiber/cafa_5
+ metrics:
+ - f1
+ - recall
+ - precision
+ ---
+ # ESM-2 for Protein Function Prediction
+
+ Please also see the more recent fine-tuned model [AmelieSchreiber/esm2_t6_8M_finetuned_cafa5](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_finetuned_cafa5).
+
+ This model is not intended for protein function prediction as-is, but rather serves as a checkpoint for further fine-tuning, especially
+ with Low Rank Adaptation (LoRA). It is an experimental model fine-tuned from the
+ [esm2_t6_8M_UR50D](https://huggingface.co/facebook/esm2_t6_8M_UR50D) model
+ for multi-label classification. In particular, the model is fine-tuned on the CAFA-5 protein sequence dataset available
+ [here](https://huggingface.co/datasets/AmelieSchreiber/cafa_5). More precisely, the `train_sequences.fasta` file contains the
+ protein sequences the model was trained on, and the `train_terms.tsv` file contains the gene ontology (GO) protein function labels
+ for each protein sequence. For more details on using ESM-2 models for sequence classification, [see here](https://huggingface.co/docs/transformers/model_doc/esm).
+ Due to the potentially complicated class weighting required by the hierarchical ontology, further fine-tuning will be necessary.
+
35
+ ## Fine-Tuning
36
+
37
+ The model was fine-tuned for 7 epochs at a learning rate of `5e-5`, and achieves the following metrics:
38
+ ```
39
+ Validation Loss: 0.0027,
40
+ Validation Micro F1: 0.3672,
41
+ Validation Macro F1: 0.9967,
42
+ Validation Micro Precision: 0.6052,
43
+ Validation Macro Precision: 0.9996,
44
+ Validation Micro Recall: 0.2626,
45
+ Validation Macro Recall: 0.9966
46
+ ```
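Note the large gap between the micro and macro scores. Micro averaging pools true/false positives over all (protein, label) pairs, while macro averaging scores each label separately and then averages, so on a highly imbalanced multi-label problem like GO annotation the two can diverge sharply: labels that are rare, or trivially all-negative, can pull the macro average far above the micro one. A small pure-Python illustration with toy data (not the model's actual outputs):

```python
def f1(tp, fp, fn):
    # A label with no positives and no predictions is scored 1.0 here;
    # this is one common convention (sklearn's zero_division parameter
    # controls the analogous behavior).
    return 1.0 if tp == fp == fn == 0 else 2 * tp / (2 * tp + fp + fn)

# Two hypothetical GO labels over four proteins:
# label A is common but recalled poorly; label B never occurs and is never predicted.
true = {"A": [1, 1, 1, 1], "B": [0, 0, 0, 0]}
pred = {"A": [1, 0, 0, 0], "B": [0, 0, 0, 0]}

counts = {}
for lab in true:
    tp = sum(t and p for t, p in zip(true[lab], pred[lab]))
    fp = sum(p and not t for t, p in zip(true[lab], pred[lab]))
    fn = sum(t and not p for t, p in zip(true[lab], pred[lab]))
    counts[lab] = (tp, fp, fn)

macro = sum(f1(*c) for c in counts.values()) / len(counts)  # averages per-label F1
micro = f1(*[sum(x) for x in zip(*counts.values())])        # pools counts first
print(macro, micro)  # macro ends up far above micro
```

Here macro F1 is 0.7 while micro F1 is only 0.4, mirroring the pattern in the table above.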
+
+ ## Using the model
+
+ First, download the `train_sequences.fasta` file and the `train_terms.tsv` file, and provide the local paths in the code below:
+
+ ```python
+ from transformers import AutoTokenizer
+ from Bio import SeqIO
+
+ # Step 1: Data Preprocessing (Replace with your local paths)
+ fasta_file = "/Users/amelieschreiber/.cursor-tutor/projects/python/cafa5/cafa-5-protein-function-prediction/Train/train_sequences.fasta"
+ tsv_file = "/Users/amelieschreiber/.cursor-tutor/projects/python/cafa5/cafa-5-protein-function-prediction/Train/train_terms.tsv"
+
+ fasta_data = {}
+ tsv_data = {}
+
+ # Map each protein id to its amino-acid sequence
+ for record in SeqIO.parse(fasta_file, "fasta"):
+     fasta_data[record.id] = str(record.seq)
+
+ # train_terms.tsv has one (protein, GO term, aspect) row per line;
+ # skip the header row and accumulate all GO terms for each protein
+ with open(tsv_file, 'r') as f:
+     next(f)
+     for line in f:
+         parts = line.strip().split("\t")
+         tsv_data.setdefault(parts[0], []).append(parts[1])
+
+ # tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+ seq_length = 1022
+ # tokenized_data = tokenizer(list(fasta_data.values()), padding=True, truncation=True, return_tensors="pt", max_length=seq_length)
+
+ unique_terms = list(set(term for terms in tsv_data.values() for term in terms))
+ ```
84
+
85
+
86
+ Second, downlowd the file `go-basic.obo` [from here](https://huggingface.co/datasets/AmelieSchreiber/cafa_5)
87
+ and store the file locally, then provide the local path in the the code below:
88
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, EsmForSequenceClassification
+
+ # 1. Parsing the go-basic.obo file
+ def parse_obo_file(file_path):
+     with open(file_path, 'r') as f:
+         data = f.read().split("[Term]")
+
+     terms = []
+     for entry in data[1:]:
+         lines = entry.strip().split("\n")
+         term = {}
+         for line in lines:
+             if line.startswith("id:"):
+                 term["id"] = line.split("id:")[1].strip()
+             elif line.startswith("name:"):
+                 term["name"] = line.split("name:")[1].strip()
+             elif line.startswith("namespace:"):
+                 term["namespace"] = line.split("namespace:")[1].strip()
+             elif line.startswith("def:"):
+                 term["definition"] = line.split("def:")[1].split('"')[1]
+         terms.append(term)
+     return terms
+
+ parsed_terms = parse_obo_file("go-basic.obo")  # Replace `go-basic.obo` with your path
+
+ # 2. Load the saved model and tokenizer
+ model_path = "AmelieSchreiber/cafa_5_protein_function_prediction"
+ loaded_model = EsmForSequenceClassification.from_pretrained(model_path)
+ loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+ # 3. The predict_protein_function function
+ def predict_protein_function(sequence, model, tokenizer, go_terms):
+     inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=1022)
+     model.eval()
+     with torch.no_grad():
+         outputs = model(**inputs)
+         predictions = torch.sigmoid(outputs.logits)
+         predicted_indices = torch.where(predictions > 0.05)[1].tolist()
+
+     functions = []
+     for idx in predicted_indices:
+         term_id = unique_terms[idx]  # Use the unique_terms list from your training script
+         for term in go_terms:
+             if term["id"] == term_id:
+                 functions.append(term["name"])
+                 break
+
+     return functions
+
+ # 4. Predicting protein function for an example sequence
+ example_sequence = "MAYLGSLVQRRLELASGDRLEASLGVGSELDVRGDRVKAVGSLDLEEGRLEQAGVSMA"  # Replace with your protein sequence
+ predicted_functions = predict_protein_function(example_sequence, loaded_model, loaded_tokenizer, parsed_terms)
+ print(predicted_functions)
+ ```
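The `0.05` cutoff used above is a free parameter: raising it typically trades recall for precision. A toy pure-Python sweep makes the trade-off concrete (the scores and labels are made up for illustration, not produced by the model):

```python
# Hypothetical per-term sigmoid scores and ground-truth labels for one protein
scores = [0.02, 0.04, 0.06, 0.30, 0.70, 0.90]
labels = [0,    0,    0,    1,    1,    1]

def precision_recall(threshold):
    preds = [1 if s > threshold else 0 for s in scores]
    tp = sum(1 for p, y in zip(preds, labels) if p and y)
    fp = sum(1 for p, y in zip(preds, labels) if p and not y)
    fn = sum(1 for p, y in zip(preds, labels) if not p and y)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

print(precision_recall(0.05))  # low cutoff: perfect recall, one false positive
print(precision_recall(0.5))   # high cutoff: perfect precision, one missed term
```

In practice the threshold should be tuned on held-out validation data against whatever metric matters for the downstream use.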
config.json ADDED
The diff for this file is too large to render. See raw diff
 
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b979922aa035be9cf44e205a6c018b31a2cd73f38cf09a8bdd957ec8f63d8adf
+ size 34583185
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "<cls>",
+   "eos_token": "<eos>",
+   "mask_token": "<mask>",
+   "pad_token": "<pad>",
+   "unk_token": "<unk>"
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "clean_up_tokenization_spaces": true,
+   "model_max_length": 1024,
+   "tokenizer_class": "EsmTokenizer"
+ }
vocab.txt ADDED
@@ -0,0 +1,33 @@
+ <cls>
+ <pad>
+ <eos>
+ <unk>
+ L
+ A
+ G
+ V
+ S
+ E
+ R
+ T
+ I
+ D
+ P
+ K
+ Q
+ N
+ F
+ Y
+ M
+ H
+ W
+ C
+ X
+ B
+ U
+ Z
+ O
+ .
+ -
+ <null_1>
+ <mask>
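Token ids follow the order of the entries above (`<cls>` = 0, `L` = 4, and so on, one token per residue). A toy re-implementation of that mapping, assuming the `<cls> … <eos>` framing the ESM tokenizer applies; in practice use `EsmTokenizer` from `transformers` rather than this sketch:

```python
# Vocabulary in the exact order listed in vocab.txt above
vocab = ["<cls>", "<pad>", "<eos>", "<unk>",
         "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
         "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
         "O", ".", "-", "<null_1>", "<mask>"]
tok2id = {tok: i for i, tok in enumerate(vocab)}

def encode(sequence):
    # One id per residue, framed by <cls> and <eos>; unknown characters map to <unk>
    return ([tok2id["<cls>"]]
            + [tok2id.get(ch, tok2id["<unk>"]) for ch in sequence]
            + [tok2id["<eos>"]])

print(encode("MAYL"))  # [0, 20, 5, 19, 4, 2]
```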