Commit ·
490b2a6
0
Parent(s):
Duplicate from AmelieSchreiber/cafa_5_protein_function_prediction
Browse filesCo-authored-by: Amelie Schreiber <AmelieSchreiber@users.noreply.huggingface.co>
- .gitattributes +35 -0
- README.md +145 -0
- config.json +0 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +7 -0
- tokenizer_config.json +5 -0
- vocab.txt +33 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
library_name: transformers
|
| 6 |
+
tags:
|
| 7 |
+
- ems
|
| 8 |
+
- esm2
|
| 9 |
+
- biology
|
| 10 |
+
- protein
|
| 11 |
+
- protein language model
|
| 12 |
+
- cafa 5
|
| 13 |
+
- protein function prediction
|
| 14 |
+
datasets:
|
| 15 |
+
- AmelieSchreiber/cafa_5
|
| 16 |
+
metrics:
|
| 17 |
+
- f1
|
| 18 |
+
- recall
|
| 19 |
+
- precision
|
| 20 |
+
---
|
| 21 |
+
# ESM-2 for Protein Function Prediction
|
| 22 |
+
|
| 23 |
+
Please also see the more recent fine-tuned model [AmelieSchreiber/esm2_t6_8M_finetuned_cafa5](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_finetuned_cafa5).
|
| 24 |
+
|
| 25 |
+
This model is not intended for protein function prediction, but rather as a checkpoint for further fine-tuning, especially
|
| 26 |
+
with Low Rank Adaptation (LoRA). This is an experimental model fine-tuned from the
|
| 27 |
+
[esm2_t6_8M_UR50D](https://huggingface.co/facebook/esm2_t6_8M_UR50D) model
|
| 28 |
+
for multi-label classification. In particular, the model is fine-tuned on the CAFA-5 protein sequence dataset available
|
| 29 |
+
[here](https://huggingface.co/datasets/AmelieSchreiber/cafa_5). More precisely, the `train_sequences.fasta` file is the
|
| 30 |
+
list of protein sequences that were trained on, and the
|
| 31 |
+
`train_terms.tsv` file contains the gene ontology protein function labels for each protein sequence. For more details on using
|
| 32 |
+
ESM-2 models for multi-label sequence classification, [see here](https://huggingface.co/docs/transformers/model_doc/esm).
|
| 33 |
+
Due to the potentially complicated class weighting necessary for the hierarchical ontology, further fine-tuning will be necessary.
|
| 34 |
+
|
| 35 |
+
## Fine-Tuning
|
| 36 |
+
|
| 37 |
+
The model was fine-tuned for 7 epochs at a learning rate of `5e-5`, and achieves the following metrics:
|
| 38 |
+
```
|
| 39 |
+
Validation Loss: 0.0027,
|
| 40 |
+
Validation Micro F1: 0.3672,
|
| 41 |
+
Validation Macro F1: 0.9967,
|
| 42 |
+
Validation Micro Precision: 0.6052,
|
| 43 |
+
Validation Macro Precision: 0.9996,
|
| 44 |
+
Validation Micro Recall: 0.2626,
|
| 45 |
+
Validation Macro Recall: 0.9966
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## Using the model
|
| 49 |
+
|
| 50 |
+
First, download the `train_sequences.fasta` file and the `train_terms.tsv` file, and provide the local paths in the code below:
|
| 51 |
+
|
| 52 |
+
```python
|
| 53 |
+
import os
|
| 54 |
+
import numpy as np
|
| 55 |
+
import torch
|
| 56 |
+
from transformers import AutoTokenizer, EsmForSequenceClassification, AdamW
|
| 57 |
+
from torch.nn.functional import binary_cross_entropy_with_logits
|
| 58 |
+
from sklearn.model_selection import train_test_split
|
| 59 |
+
from sklearn.metrics import f1_score, precision_score, recall_score
|
| 60 |
+
# from accelerate import Accelerator
|
| 61 |
+
from Bio import SeqIO
|
| 62 |
+
|
| 63 |
+
# Step 1: Data Preprocessing (Replace with your local paths)
|
| 64 |
+
fasta_file = "/Users/amelieschreiber/.cursor-tutor/projects/python/cafa5/cafa-5-protein-function-prediction/Train/train_sequences.fasta"
|
| 65 |
+
tsv_file = "/Users/amelieschreiber/.cursor-tutor/projects/python/cafa5/cafa-5-protein-function-prediction/Train/train_terms.tsv"
|
| 66 |
+
|
| 67 |
+
fasta_data = {}
|
| 68 |
+
tsv_data = {}
|
| 69 |
+
|
| 70 |
+
for record in SeqIO.parse(fasta_file, "fasta"):
|
| 71 |
+
fasta_data[record.id] = str(record.seq)
|
| 72 |
+
|
| 73 |
+
with open(tsv_file, 'r') as f:
|
| 74 |
+
for line in f:
|
| 75 |
+
parts = line.strip().split("\t")
|
| 76 |
+
tsv_data[parts[0]] = parts[1:]
|
| 77 |
+
|
| 78 |
+
# tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
| 79 |
+
seq_length = 1022
|
| 80 |
+
# tokenized_data = tokenizer(list(fasta_data.values()), padding=True, truncation=True, return_tensors="pt", max_length=seq_length)
|
| 81 |
+
|
| 82 |
+
unique_terms = list(set(term for terms in tsv_data.values() for term in terms))
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
Second, downlowd the file `go-basic.obo` [from here](https://huggingface.co/datasets/AmelieSchreiber/cafa_5)
|
| 87 |
+
and store the file locally, then provide the local path in the the code below:
|
| 88 |
+
|
| 89 |
+
```python
|
| 90 |
+
import torch
|
| 91 |
+
from transformers import AutoTokenizer, EsmForSequenceClassification
|
| 92 |
+
from sklearn.metrics import precision_recall_fscore_support
|
| 93 |
+
|
| 94 |
+
# 1. Parsing the go-basic.obo file
|
| 95 |
+
def parse_obo_file(file_path):
|
| 96 |
+
with open(file_path, 'r') as f:
|
| 97 |
+
data = f.read().split("[Term]")
|
| 98 |
+
|
| 99 |
+
terms = []
|
| 100 |
+
for entry in data[1:]:
|
| 101 |
+
lines = entry.strip().split("\n")
|
| 102 |
+
term = {}
|
| 103 |
+
for line in lines:
|
| 104 |
+
if line.startswith("id:"):
|
| 105 |
+
term["id"] = line.split("id:")[1].strip()
|
| 106 |
+
elif line.startswith("name:"):
|
| 107 |
+
term["name"] = line.split("name:")[1].strip()
|
| 108 |
+
elif line.startswith("namespace:"):
|
| 109 |
+
term["namespace"] = line.split("namespace:")[1].strip()
|
| 110 |
+
elif line.startswith("def:"):
|
| 111 |
+
term["definition"] = line.split("def:")[1].split('"')[1]
|
| 112 |
+
terms.append(term)
|
| 113 |
+
return terms
|
| 114 |
+
|
| 115 |
+
parsed_terms = parse_obo_file("go-basic.obo") # Replace `go-basic.obo` with your path
|
| 116 |
+
|
| 117 |
+
# 2. Load the saved model and tokenizer
|
| 118 |
+
model_path = "AmelieSchreiber/cafa_5_protein_function_prediction"
|
| 119 |
+
loaded_model = EsmForSequenceClassification.from_pretrained(model_path)
|
| 120 |
+
loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 121 |
+
|
| 122 |
+
# 3. The predict_protein_function function
|
| 123 |
+
def predict_protein_function(sequence, model, tokenizer, go_terms):
|
| 124 |
+
inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=1022)
|
| 125 |
+
model.eval()
|
| 126 |
+
with torch.no_grad():
|
| 127 |
+
outputs = model(**inputs)
|
| 128 |
+
predictions = torch.sigmoid(outputs.logits)
|
| 129 |
+
predicted_indices = torch.where(predictions > 0.05)[1].tolist()
|
| 130 |
+
|
| 131 |
+
functions = []
|
| 132 |
+
for idx in predicted_indices:
|
| 133 |
+
term_id = unique_terms[idx] # Use the unique_terms list from your training script
|
| 134 |
+
for term in go_terms:
|
| 135 |
+
if term["id"] == term_id:
|
| 136 |
+
functions.append(term["name"])
|
| 137 |
+
break
|
| 138 |
+
|
| 139 |
+
return functions
|
| 140 |
+
|
| 141 |
+
# 4. Predicting protein function for an example sequence
|
| 142 |
+
example_sequence = "MAYLGSLVQRRLELASGDRLEASLGVGSELDVRGDRVKAVGSLDLEEGRLEQAGVSMA" # Replace with your protein sequence
|
| 143 |
+
predicted_functions = predict_protein_function(example_sequence, loaded_model, loaded_tokenizer, parsed_terms)
|
| 144 |
+
print(predicted_functions)
|
| 145 |
+
```
|
config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b979922aa035be9cf44e205a6c018b31a2cd73f38cf09a8bdd957ec8f63d8adf
|
| 3 |
+
size 34583185
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": "<cls>",
|
| 3 |
+
"eos_token": "<eos>",
|
| 4 |
+
"mask_token": "<mask>",
|
| 5 |
+
"pad_token": "<pad>",
|
| 6 |
+
"unk_token": "<unk>"
|
| 7 |
+
}
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"clean_up_tokenization_spaces": true,
|
| 3 |
+
"model_max_length": 1024,
|
| 4 |
+
"tokenizer_class": "EsmTokenizer"
|
| 5 |
+
}
|
vocab.txt
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<cls>
|
| 2 |
+
<pad>
|
| 3 |
+
<eos>
|
| 4 |
+
<unk>
|
| 5 |
+
L
|
| 6 |
+
A
|
| 7 |
+
G
|
| 8 |
+
V
|
| 9 |
+
S
|
| 10 |
+
E
|
| 11 |
+
R
|
| 12 |
+
T
|
| 13 |
+
I
|
| 14 |
+
D
|
| 15 |
+
P
|
| 16 |
+
K
|
| 17 |
+
Q
|
| 18 |
+
N
|
| 19 |
+
F
|
| 20 |
+
Y
|
| 21 |
+
M
|
| 22 |
+
H
|
| 23 |
+
W
|
| 24 |
+
C
|
| 25 |
+
X
|
| 26 |
+
B
|
| 27 |
+
U
|
| 28 |
+
Z
|
| 29 |
+
O
|
| 30 |
+
.
|
| 31 |
+
-
|
| 32 |
+
<null_1>
|
| 33 |
+
<mask>
|