---
language:
- en
---

# VESM: Co-distillation of ESM models for Variant Effect Prediction

This repository contains the VESM protein language models developed in the paper ["VESM: Compressing the collective knowledge of ESM into a single protein language model"](vesm_arxiv) by Tuan Dinh, Seon-Kyeong Jang, Noah Zaitlen and Vasilis Ntranos.

---
## Quick start <a name="quickstart"></a>

A simple way to get started is to run our notebook directly on a Google Colab instance:

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ntranoslab/vesm/blob/main/notebooks/VESM_Getting_Started.ipynb)

See also https://github.com/ntranoslab/vesm
## Download models

**Using python**

```py
from huggingface_hub import snapshot_download, hf_hub_download

local_dir = './vesm'

# Download a single model
model_name = "VESM_3B"  # one of: "VESM_35M", "VESM_150M", "VESM_650M", "VESM_3B", "VESM3"
hf_hub_download(repo_id="ntranoslab/vesm", filename=f"{model_name}.pth", local_dir=local_dir)

# Download all models
snapshot_download(repo_id="ntranoslab/vesm", local_dir=local_dir)
```

**Using the Hugging Face CLI**

```bash
huggingface-cli download ntranoslab/vesm --local-dir ./vesm
```

---

## Usage <a name="usage"></a>

Below is a simple example of using our models to predict variant effects.

**Loading helpers**
```py
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, EsmForMaskedLM

# base checkpoints corresponding to each VESM model
esm_dict = {
    "VESM_35M": 'facebook/esm2_t12_35M_UR50D',
    "VESM_150M": 'facebook/esm2_t30_150M_UR50D',
    "VESM_650M": 'facebook/esm2_t33_650M_UR50D',
    "VESM_3B": 'facebook/esm2_t36_3B_UR50D',
    "VESM3": "esm3_sm_open_v1"
}

def load_vesm(model_name="VESM_3B", local_dir="vesm", device='cuda'):
    if model_name not in esm_dict:
        raise ValueError(f"Unknown model: {model_name}. Choose from {list(esm_dict)}.")
    ckt = esm_dict[model_name]

    # download the VESM weights
    hf_hub_download(repo_id="ntranoslab/vesm", filename=f"{model_name}.pth", local_dir=local_dir)

    # load the base model and tokenizer
    if model_name == "VESM3":
        from esm.models.esm3 import ESM3
        model = ESM3.from_pretrained(ckt, device=device).to(torch.float)
        tokenizer = model.tokenizers.sequence
    else:
        model = EsmForMaskedLM.from_pretrained(ckt).to(device)
        tokenizer = AutoTokenizer.from_pretrained(ckt)

    # load the pretrained VESM weights on top of the base model
    model.load_state_dict(torch.load(f'{local_dir}/{model_name}.pth', map_location=device), strict=False)
    return model, tokenizer
```

**Variant Effect Prediction**

```py
# scoring functions
import torch.nn.functional as F

# calculate per-position log-likelihood ratios (LLRs) from the logits
def get_llrs(sequence_logits, input_ids):
    token_probs = torch.log_softmax(sequence_logits, dim=-1)
    # mask selecting the wild-type token at each position
    wt_positions = F.one_hot(input_ids, num_classes=token_probs.shape[-1])
    wt_probs = (token_probs * wt_positions).sum(dim=-1, keepdim=True)
    # LLR = log p(token) - log p(wild-type token)
    llrs = token_probs - wt_probs.expand(token_probs.shape)
    return llrs

# compute a mutation score by summing LLRs over the (colon-separated) substitutions
def score_mutation(llrs, mutation, sequence_vocabs):
    mutation_score = 0
    for mut in mutation.split(":"):
        # e.g. "M1Y": wild-type M, 1-based position 1, mutant Y
        _, idx, mt = mut[0], int(mut[1:-1]), mut[-1]
        # the leading BOS/CLS token shifts token indices by one, so the
        # 1-based protein position indexes the LLR matrix directly
        pred = llrs[idx, sequence_vocabs[mt]]
        mutation_score += pred.item()
    return mutation_score
```
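By construction, the LLR of the wild-type residue at every position is zero, so only substitutions move the score. The following sanity check reproduces `get_llrs` inline so it runs standalone on toy random logits (no model or download required); the toy shapes and ids are our own, not from the repository:

```python
import torch
import torch.nn.functional as F

def get_llrs(sequence_logits, input_ids):
    token_probs = torch.log_softmax(sequence_logits, dim=-1)
    wt_positions = F.one_hot(input_ids, num_classes=token_probs.shape[-1])
    wt_probs = (token_probs * wt_positions).sum(dim=-1, keepdim=True)
    return token_probs - wt_probs.expand(token_probs.shape)

# toy setup: 5 token positions, vocabulary of size 8
torch.manual_seed(0)
logits = torch.randn(5, 8)
input_ids = torch.tensor([1, 3, 2, 0, 5])

llrs = get_llrs(logits, input_ids)
# wild-type entries are exactly zero
wt_llrs = llrs.gather(-1, input_ids.unsqueeze(-1))
print(torch.allclose(wt_llrs, torch.zeros_like(wt_llrs)))  # True
```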
#### Sequence-only Models

Here, we provide sample scripts to compute mutation scores with the sequence-only VESM models.

```py
# sequence and mutation
sequence = "MVNSTHRGMHTSLHLWNRSSYRLHSNASESLGKGYSDGGCYEQLFVSPEVFVTLGVISLLENILV"
mutation = "M1Y:V2T"
```
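Mutations use `<wild-type><1-based position><mutant>` notation, with `:` separating multiple substitutions. A small helper can validate a mutation string against the sequence before scoring; `check_mutation` is our own illustrative name, not part of the repository:

```python
def check_mutation(sequence, mutation):
    # verify that each substitution's wild-type residue matches the sequence
    for mut in mutation.split(":"):
        wt, pos, mt = mut[0], int(mut[1:-1]), mut[-1]
        if sequence[pos - 1] != wt:
            raise ValueError(f"{mut}: expected {wt} at position {pos}, found {sequence[pos - 1]}")
    return True

sequence = "MVNSTHRGMHTSLHLWNRSSYRLHSNASESLGKGYSDGGCYEQLFVSPEVFVTLGVISLLENILV"
print(check_mutation(sequence, "M1Y:V2T"))  # True
```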
```py
# Settings
local_dir = 'vesm'
gpu_id = 0
device = torch.device(f'cuda:{gpu_id}') if torch.cuda.is_available() else 'cpu'

# Helper: run the model and return the LLR matrix for a sequence
def inference(model, tokenizer, sequence, device):
    tokens = tokenizer([sequence], return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model(**tokens)
    logits = outputs['logits'][0]
    input_ids = tokens['input_ids'][0]
    # calculate log-likelihood ratios from the logits
    llrs = get_llrs(logits, input_ids)
    return llrs

# Prediction with VESM
model_name = 'VESM_3B'
model, tokenizer = load_vesm(model_name, local_dir=local_dir, device=device)
sequence_vocabs = tokenizer.get_vocab()

# compute the mutation score
llrs = inference(model, tokenizer, sequence, device)
mutation_score = score_mutation(llrs, mutation, sequence_vocabs)
print(f"Predicted score by {model_name}:", mutation_score)
```
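The same LLR matrix scores every possible single substitution at once, which is useful for saturation-mutagenesis style analyses. A hedged sketch of that loop, using a toy LLR tensor and vocabulary as stand-ins for the `llrs` and `sequence_vocabs` produced above:

```python
import torch

# toy stand-ins: in practice use llrs from inference() and tokenizer.get_vocab()
amino_acids = "ACDEFGHIKLMNPQRSTVWY"
sequence = "MVN"
sequence_vocabs = {aa: i for i, aa in enumerate(amino_acids)}
torch.manual_seed(0)
llrs = torch.randn(len(sequence) + 2, len(amino_acids))  # +2 rows for BOS/EOS

# score all single substitutions at every position (1-based, offset by BOS)
scores = {
    f"{wt}{pos}{mt}": llrs[pos, sequence_vocabs[mt]].item()
    for pos, wt in enumerate(sequence, start=1)
    for mt in amino_acids if mt != wt
}
print(len(scores))  # 57 substitutions for a 3-residue toy sequence
```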
#### Using Structure with VESM3

```py
from esm.sdk.api import ESMProtein

# A sample AlphaFold structure (PDB); download the latest version, e.g.:
# !wget https://alphafold.ebi.ac.uk/files/AF-P32245-F1-model_v6.pdb
pdb_file = "AF-P32245-F1-model_v6.pdb"
protein = ESMProtein.from_pdb(pdb_file)
mutation = "M1Y:V2T"
```

```py
# load model
model, tokenizer = load_vesm('VESM3', local_dir=local_dir, device=device)
sequence_vocabs = tokenizer.get_vocab()

# inference with both sequence and structure tokens
tokens = model.encode(protein)
seq_tokens = tokens.sequence.reshape(1, -1)
struct_tokens = tokens.structure.reshape(1, -1)
with torch.no_grad():
    outs = model.forward(sequence_tokens=seq_tokens, structure_tokens=struct_tokens)
logits = outs.sequence_logits[0, :, :]
input_ids = tokens.sequence

# calculate log-likelihood ratios from the logits
llrs = get_llrs(logits, input_ids)
# compute the mutation score
mutation_score = score_mutation(llrs, mutation, sequence_vocabs)
print("mutation score:", mutation_score)
```
## License <a name="license"></a>

The source code and model weights for VESM models are distributed under the MIT License.

The VESM3 model is a fine-tuned version of ESM3-Open (EvolutionaryScale) and is available under a [non-commercial license agreement](https://www.evolutionaryscale.ai/policies/cambrian-open-license-agreement).