---
license: mit
---

##### 🔴 <font color=red>Note: SaProt requires structural (SA token) input for optimal performance. AA-sequence-only mode works but must be fine-tuned; frozen embeddings work only for SA sequences, not for AA sequences. With structural input, SaProt surpasses ESM2 on most tasks.</font>

We provide two ways to use SaProt: through the Huggingface class, or in the same way as in [esm github](https://github.com/facebookresearch/esm). Users can choose either one.

## Huggingface model
The following code shows how to load the model.
```python
from transformers import EsmTokenizer, EsmForMaskedLM

model_path = "/your/path/to/SaProt_650M_AF2"
tokenizer = EsmTokenizer.from_pretrained(model_path)
model = EsmForMaskedLM.from_pretrained(model_path)

#################### Example ####################
device = "cuda"
model.to(device)

seq = "M#EvVpQpL#VyQdYaKv"  # Here "#" represents low-pLDDT regions (pLDDT < 70)
tokens = tokenizer.tokenize(seq)
print(tokens)

inputs = tokenizer(seq, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

outputs = model(**inputs)
print(outputs.logits.shape)

"""
['M#', 'Ev', 'Vp', 'Qp', 'L#', 'Vy', 'Qd', 'Ya', 'Kv']
torch.Size([1, 11, 446])
"""
```
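Each structure-aware (SA) token above pairs an amino-acid letter (uppercase) with a Foldseek 3Di structural state (lowercase); `#` stands in for positions whose structure is masked (e.g. low pLDDT), and an AA-sequence-only input simply uses `#` at every structural position. Below is a minimal illustrative sketch of how such a sequence can be assembled from an amino-acid string and a 3Di string. `combine_aa_3di` is a hypothetical helper written for this example, not a function from the SaProt repository; in practice the 3Di string comes from running Foldseek on a structure file (note the `foldseek_path` field in the mutation-model config further below).
```python
# Illustrative sketch only: assemble a structure-aware (SA) sequence from an
# amino-acid string and a Foldseek 3Di string. combine_aa_3di is hypothetical,
# not part of the SaProt codebase.
def combine_aa_3di(aa_seq, di_seq, low_plddt=None):
    """Pair each amino acid (uppercase) with its 3Di state (lowercase);
    use '#' where the structure should be masked (e.g. pLDDT < 70)."""
    assert len(aa_seq) == len(di_seq)
    tokens = []
    for i, (aa, di) in enumerate(zip(aa_seq.upper(), di_seq.lower())):
        masked = low_plddt is not None and low_plddt[i]
        tokens.append(aa + ("#" if masked else di))
    return "".join(tokens)


aa_seq = "MEVQLVQYK"
di_seq = "dvppdydav"  # placeholder 3Di states; positions 0 and 4 will be masked anyway
low_plddt = [True, False, False, False, True, False, False, False, False]
print(combine_aa_3di(aa_seq, di_seq, low_plddt))  # -> "M#EvVpQpL#VyQdYaKv"
```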

## esm model
The esm version is also stored in the same folder, named `SaProt_650M_AF2.pt`. We provide a function to load the model.
```python
from utils.esm_loader import load_esm_saprot

model_path = "/your/path/to/SaProt_650M_AF2.pt"
model, alphabet = load_esm_saprot(model_path)
```
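Once loaded, the model can be used in the usual esm style. The following is a minimal sketch assuming the returned `model` and `alphabet` follow the standard fair-esm interface (batch converter plus `repr_layers`); check the SaProt repository for the exact supported calls.
```python
import torch

# Hedged sketch: fair-esm-style inference. This assumes the objects returned by
# load_esm_saprot expose the usual fair-esm interface.
batch_converter = alphabet.get_batch_converter()
data = [("protein_1", "M#EvVpQpL#VyQdYaKv")]
labels, strs, tokens = batch_converter(data)

with torch.no_grad():
    out = model(tokens, repr_layers=[33])  # layer 33 is the last layer of the 650M model
token_representations = out["representations"][33]
print(token_representations.shape)
```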

## Predict mutational effect
We provide a function to predict the mutational effect of a protein sequence. The example below shows how to predict the effect of mutations at specific positions. If you are using an AF2 structure, we strongly recommend adding the pLDDT mask (see below).
```python
from model.saprot.saprot_foldseek_mutation_model import SaprotFoldseekMutationModel


config = {
    "foldseek_path": None,
    "config_path": "/your/path/to/SaProt_650M_AF2",  # Note this is the directory path of SaProt, not the ".pt" file
    "load_pretrained": True,
}
model = SaprotFoldseekMutationModel(**config)
tokenizer = model.tokenizer

device = "cuda"
model.eval()
model.to(device)

seq = "M#EvVpQpL#VyQdYaKv"  # Here "#" represents low-pLDDT regions (pLDDT < 70)

# Predict the effect of mutating the 3rd amino acid to A
mut_info = "V3A"
mut_value = model.predict_mut(seq, mut_info)
print(mut_value)

# Predict the effect of combinatorial mutations, e.g. mutating the 3rd amino acid to A and the 4th amino acid to M
mut_info = "V3A:Q4M"
mut_value = model.predict_mut(seq, mut_info)
print(mut_value)

# Predict the effects of all possible mutations at the 3rd position
mut_pos = 3
mut_dict = model.predict_pos_mut(seq, mut_pos)
print(mut_dict)

# Predict the probabilities of all amino acids at the 3rd position
mut_pos = 3
mut_dict = model.predict_pos_prob(seq, mut_pos)
print(mut_dict)
```
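For reference, single-point mutation scores of this kind are conventionally the log-odds ratio between the mutant and wild-type probabilities at the mutated position. Assuming `predict_mut` follows that convention, the sketch below reproduces the `V3A` score from `predict_pos_prob`; treat it as an illustration rather than the repository's definitive implementation.
```python
import math

# Hedged sketch: recompute the V3A score from the position-wise probabilities,
# assuming the score is log P(mutant) - log P(wild type) at position 3.
probs = model.predict_pos_prob(seq, 3)
manual_score = math.log(probs["A"]) - math.log(probs["V"])
print(manual_score)  # expected to be close to model.predict_mut(seq, "V3A") if the convention holds
```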

## Get protein embeddings
If you want to generate protein embeddings, you can refer to the following code. The embeddings are the average of the hidden states of the last layer. <font color=red>Note that frozen SaProt supports SA-sequence embeddings but not AA-sequence embeddings.</font>
```python
from model.saprot.base import SaprotBaseModel
from transformers import EsmTokenizer


config = {
    "task": "base",
    "config_path": "/your/path/to/SaProt_650M_AF2",  # Note this is the directory path of SaProt, not the ".pt" file
    "load_pretrained": True,
}

model = SaprotBaseModel(**config)
tokenizer = EsmTokenizer.from_pretrained(config["config_path"])

device = "cuda"
model.to(device)

seq = "M#EvVpQpL#VyQdYaKv"  # Here "#" represents low-pLDDT regions (pLDDT < 70)
tokens = tokenizer.tokenize(seq)
print(tokens)

inputs = tokenizer(seq, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

embeddings = model.get_hidden_states(inputs, reduction="mean")
print(embeddings[0].shape)
```
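As a quick illustration of downstream use, the mean-pooled embeddings can be compared directly, for example with cosine similarity. The sketch below assumes a second, hypothetical SA sequence `seq2` embedded exactly as above.
```python
import torch.nn.functional as F

# Illustrative sketch: compare two mean-pooled SaProt embeddings.
# seq2 is a hypothetical second structure-aware sequence, embedded as above.
seq2 = "M#EvVpQpL#VyQdYaKl"
inputs2 = tokenizer(seq2, return_tensors="pt")
inputs2 = {k: v.to(device) for k, v in inputs2.items()}
embeddings2 = model.get_hidden_states(inputs2, reduction="mean")

similarity = F.cosine_similarity(embeddings[0].unsqueeze(0), embeddings2[0].unsqueeze(0))
print(similarity.item())
```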