Update README.md
Browse files
README.md
CHANGED
|
@@ -54,7 +54,7 @@ esm_dict = {
|
|
| 54 |
"VESM_3B": 'facebook/esm2_t36_3B_UR50D',
|
| 55 |
"VESM3": "esm3_sm_open_v1"
|
| 56 |
}
|
| 57 |
-
def load_vesm(model_name="
|
| 58 |
if model_name in esm_dict:
|
| 59 |
ckt = esm_dict[model_name]
|
| 60 |
else:
|
|
@@ -66,7 +66,7 @@ def load_vesm(model_name="VESM_35M", local_dir="vesm", device='cuda'):
|
|
| 66 |
# load base model
|
| 67 |
if model_name == "VESM3":
|
| 68 |
from esm.models.esm3 import ESM3
|
| 69 |
-
model = ESM3.from_pretrained(
|
| 70 |
tokenizer = model.tokenizers.sequence
|
| 71 |
else:
|
| 72 |
model = EsmForMaskedLM.from_pretrained(ckt).to(device)
|
|
@@ -131,7 +131,7 @@ def inference(model, tokenizer, sequence, device):
|
|
| 131 |
Prediction with VESM models
|
| 132 |
"""
|
| 133 |
# load vesm models
|
| 134 |
-
model_name = '
|
| 135 |
model, tokenizer = load_vesm(model_name, local_dir=local_dir, device=device)
|
| 136 |
sequence_vocabs = tokenizer.get_vocab()
|
| 137 |
# inference
|
|
@@ -145,9 +145,9 @@ print(f"Predicted score by {model_name}: ", mutant_score)
|
|
| 145 |
```py
|
| 146 |
from esm.sdk.api import ESMProtein
|
| 147 |
|
| 148 |
-
# A sample structure pdb
|
| 149 |
-
# !wget https://alphafold.ebi.ac.uk/files/AF-P32245-F1-
|
| 150 |
-
pdb_file = "AF-P32245-F1-
|
| 151 |
protein = ESMProtein.from_pdb(pdb_file)
|
| 152 |
mutant = "M1Y:V2T"
|
| 153 |
```
|
|
@@ -166,7 +166,7 @@ with torch.no_grad():
|
|
| 166 |
logits = outs.sequence_logits[0, :, :]
|
| 167 |
input_ids = tokens.sequence
|
| 168 |
|
| 169 |
-
#
|
| 170 |
llrs = get_llrs(logits, input_ids)
|
| 171 |
# compute mutant score
|
| 172 |
mutant_score = score_mutant(llrs, mutant, sequence_vocabs)
|
|
|
|
| 54 |
"VESM_3B": 'facebook/esm2_t36_3B_UR50D',
|
| 55 |
"VESM3": "esm3_sm_open_v1"
|
| 56 |
}
|
| 57 |
+
def load_vesm(model_name="VESM_3B", local_dir="vesm", device='cuda'):
|
| 58 |
if model_name in esm_dict:
|
| 59 |
ckt = esm_dict[model_name]
|
| 60 |
else:
|
|
|
|
| 66 |
# load base model
|
| 67 |
if model_name == "VESM3":
|
| 68 |
from esm.models.esm3 import ESM3
|
| 69 |
+
model = ESM3.from_pretrained(ckt, device=device).to(torch.float)
|
| 70 |
tokenizer = model.tokenizers.sequence
|
| 71 |
else:
|
| 72 |
model = EsmForMaskedLM.from_pretrained(ckt).to(device)
|
|
|
|
| 131 |
Prediction with VESM models
|
| 132 |
"""
|
| 133 |
# load vesm models
|
| 134 |
+
model_name = 'VESM_3B'
|
| 135 |
model, tokenizer = load_vesm(model_name, local_dir=local_dir, device=device)
|
| 136 |
sequence_vocabs = tokenizer.get_vocab()
|
| 137 |
# inference
|
|
|
|
| 145 |
```py
|
| 146 |
from esm.sdk.api import ESMProtein
|
| 147 |
|
| 148 |
+
# A sample structure pdb: download the latest version
|
| 149 |
+
# !wget https://alphafold.ebi.ac.uk/files/AF-P32245-F1-model_v6.pdb
|
| 150 |
+
pdb_file = "AF-P32245-F1-model_v6.pdb"
|
| 151 |
protein = ESMProtein.from_pdb(pdb_file)
|
| 152 |
mutant = "M1Y:V2T"
|
| 153 |
```
|
|
|
|
| 166 |
logits = outs.sequence_logits[0, :, :]
|
| 167 |
input_ids = tokens.sequence
|
| 168 |
|
| 169 |
+
# calculate log-likelihood ratio from the logits
|
| 170 |
llrs = get_llrs(logits, input_ids)
|
| 171 |
# compute mutant score
|
| 172 |
mutant_score = score_mutant(llrs, mutant, sequence_vocabs)
|