Update README.md
Browse files
README.md
CHANGED
|
@@ -4,7 +4,7 @@ license: mit
|
|
| 4 |
We provide two ways to use SaProt: through the Huggingface class, and
|
| 5 |
in the same way as in the [esm github](https://github.com/facebookresearch/esm). Users can choose either one to use.
|
| 6 |
|
| 7 |
-
##
|
| 8 |
The following code shows how to load the model.
|
| 9 |
```
|
| 10 |
from transformers import EsmTokenizer, EsmForMaskedLM
|
|
@@ -33,11 +33,88 @@ torch.Size([1, 11, 446])
|
|
| 33 |
"""
|
| 34 |
```
|
| 35 |
|
| 36 |
-
##
|
| 37 |
The esm version is also stored in the same folder, named `SaProt_650M_AF2.pt`. We provide a function to load the model.
|
| 38 |
```
|
| 39 |
from utils.esm_loader import load_esm_saprot
|
| 40 |
|
| 41 |
model_path = "/your/path/to/SaProt_650M_AF2.pt"
|
| 42 |
model, alphabet = load_esm_saprot(model_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
```
|
|
|
|
| 4 |
We provide two ways to use SaProt: through the Huggingface class, and
|
| 5 |
in the same way as in the [esm github](https://github.com/facebookresearch/esm). Users can choose either one to use.
|
| 6 |
|
| 7 |
+
## Huggingface model
|
| 8 |
The following code shows how to load the model.
|
| 9 |
```
|
| 10 |
from transformers import EsmTokenizer, EsmForMaskedLM
|
|
|
|
| 33 |
"""
|
| 34 |
```
|
| 35 |
|
| 36 |
+
## esm model
|
| 37 |
The esm version is also stored in the same folder, named `SaProt_650M_AF2.pt`. We provide a function to load the model.
|
| 38 |
```
|
| 39 |
from utils.esm_loader import load_esm_saprot
|
| 40 |
|
| 41 |
model_path = "/your/path/to/SaProt_650M_AF2.pt"
|
| 42 |
model, alphabet = load_esm_saprot(model_path)
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## Predict mutational effect
|
| 46 |
+
We provide a function to predict the mutational effect of a protein sequence. The example below shows how to predict
|
| 47 |
+
the mutational effect at a specific position. If using the AF2 structure, we strongly recommend that you add the pLDDT mask (see below).
|
| 48 |
+
```python
|
| 49 |
+
from model.saprot.saprot_foldseek_mutation_model import SaprotFoldseekMutationModel
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
config = {
|
| 53 |
+
"foldseek_path": None,
|
| 54 |
+
"config_path": "/your/path/to/SaProt_650M_AF2", # Note this is the directory path of SaProt, not the ".pt" file
|
| 55 |
+
"load_pretrained": True,
|
| 56 |
+
}
|
| 57 |
+
model = SaprotFoldseekMutationModel(**config)
|
| 58 |
+
tokenizer = model.tokenizer
|
| 59 |
+
|
| 60 |
+
device = "cuda"
|
| 61 |
+
model.eval()
|
| 62 |
+
model.to(device)
|
| 63 |
+
|
| 64 |
+
seq = "M#EvVpQpL#VyQdYaKv" # Here "#" represents low-pLDDT regions (pLDDT < 70)
|
| 65 |
+
|
| 66 |
+
# Predict the effect of mutating the 3rd amino acid to A
|
| 67 |
+
mut_info = "V3A"
|
| 68 |
+
mut_value = model.predict_mut(seq, mut_info)
|
| 69 |
+
print(mut_value)
|
| 70 |
+
|
| 71 |
+
# Predict all effects of mutations at 3rd position
|
| 72 |
+
mut_pos = 3
|
| 73 |
+
mut_dict = model.predict_pos_mut(seq, mut_pos)
|
| 74 |
+
print(mut_dict)
|
| 75 |
+
|
| 76 |
+
# Predict probabilities of all amino acids at 3rd position
|
| 77 |
+
mut_pos = 3
|
| 78 |
+
mut_dict = model.predict_pos_prob(seq, mut_pos)
|
| 79 |
+
print(mut_dict)
|
| 80 |
+
|
| 81 |
+
"""
|
| 82 |
+
0.7908501625061035
|
| 83 |
+
|
| 84 |
+
{'V3A': 0.7908501625061035, 'V3C': -0.9117952585220337, 'V3D': 2.7700226306915283, 'V3E': 2.3255627155303955, 'V3F': 0.2094242423772812, 'V3G': 2.699633836746216, 'V3H': 1.240191102027893, 'V3I': 0.10231903940439224, 'V3K': 1.804598093032837,
|
| 85 |
+
'V3L': 1.3324960470199585, 'V3M': -0.18938277661800385, 'V3N': 2.8249857425689697, 'V3P': 0.40185314416885376, 'V3Q': 1.8361762762069702, 'V3R': 1.1899691820144653, 'V3S': 2.2159857749938965, 'V3T': 0.8813426494598389, 'V3V': 0.0, 'V3W': 0.5853186249732971, 'V3Y': 0.17449656128883362}
|
| 86 |
+
|
| 87 |
+
{'A': 0.021275954321026802, 'C': 0.0038764977362006903, 'D': 0.15396881103515625, 'E': 0.0987202599644661, 'F': 0.011895398609340191, 'G': 0.14350374042987823, 'H': 0.03334535285830498, 'I': 0.010687196627259254, 'K': 0.058634623885154724, 'L': 0.03656982257962227, 'M': 0.00798324216157198, 'N': 0.16266827285289764, 'P': 0.014419485814869404, 'Q': 0.06051575019955635, 'R': 0.03171204403042793, 'S': 0.08847439289093018, 'T': 0.023291070014238358, 'V': 0.009647775441408157, 'W': 0.017323188483715057, 'Y': 0.011487090960144997}
|
| 88 |
+
"""
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
## Get protein embeddings
|
| 92 |
+
If you want to generate protein embeddings, you can refer to the following code. The embeddings are the average of
|
| 93 |
+
the hidden states of the last layer.
|
| 94 |
+
```python
|
| 95 |
+
from model.saprot.base import SaprotBaseModel
|
| 96 |
+
from transformers import EsmTokenizer
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
config = {
|
| 100 |
+
"task": "base",
|
| 101 |
+
"config_path": "/your/path/to/SaProt_650M_AF2", # Note this is the directory path of SaProt, not the ".pt" file
|
| 102 |
+
"load_pretrained": True,
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
model = SaprotBaseModel(**config)
|
| 106 |
+
tokenizer = EsmTokenizer.from_pretrained(config["config_path"])
|
| 107 |
+
|
| 108 |
+
device = "cuda"
|
| 109 |
+
model.to(device)
|
| 110 |
+
|
| 111 |
+
seq = "M#EvVpQpL#VyQdYaKv" # Here "#" represents low-pLDDT regions (pLDDT < 70)
|
| 112 |
+
tokens = tokenizer.tokenize(seq)
|
| 113 |
+
print(tokens)
|
| 114 |
+
|
| 115 |
+
inputs = tokenizer(seq, return_tensors="pt")
|
| 116 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 117 |
+
|
| 118 |
+
embeddings = model.get_hidden_states(inputs, reduction="mean")
|
| 119 |
+
print(embeddings[0].shape)
|
| 120 |
```
|