VESM: Co-distillation of ESM models for Variant Effect Prediction
This repository contains the VESM protein language models developed in the paper "VESM: Compressing the collective knowledge of ESM into a single protein language model" by Tuan Dinh, Seon-Kyeong Jang, Noah Zaitlen and Vasilis Ntranos.
Quick start
A simple way to get started is to run our notebook directly on a Google Colab instance:
See also https://github.com/ntranoslab/vesm
Download models
Using Python
from huggingface_hub import snapshot_download, hf_hub_download
local_dir = './vesm'
# Download a single model, picked by its position in this tuple
model_offset = 0
model_name = ("VESM_35M", "VESM_150M", "VESM_650M", "VESM_3B", "VESM3")[model_offset]
hf_hub_download(
    repo_id="ntranoslab/vesm",
    filename=f"{model_name}.pth",
    local_dir=local_dir,
)
# ...or fetch every model file in the repository in one call
snapshot_download(repo_id="ntranoslab/vesm", local_dir=local_dir)
Using the Hugging Face CLI
# Download all VESM checkpoints into ./vesm (the same folder the Python example uses)
huggingface-cli download ntranoslab/vesm --local-dir ./vesm
Usage
Below is a simple example of how to use our models to predict variant effects.
Loading helpers
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, EsmForMaskedLM
# Each VESM variant -> the base checkpoint that load_vesm builds before
# loading the VESM weights on top of it.
esm_dict = dict(
    VESM_35M='facebook/esm2_t12_35M_UR50D',
    VESM_150M='facebook/esm2_t30_150M_UR50D',
    VESM_650M='facebook/esm2_t33_650M_UR50D',
    VESM_3B='facebook/esm2_t36_3B_UR50D',
    VESM3="esm3_sm_open_v1",
)
def load_vesm(model_name="VESM_3B", local_dir="vesm", device='cuda'):
    """Load a VESM model together with its tokenizer.

    Downloads ``{model_name}.pth`` from the ``ntranoslab/vesm`` Hugging Face
    repository into ``local_dir``, builds the base model listed for
    ``model_name`` in ``esm_dict`` and loads the VESM weights into it.

    Args:
        model_name: a key of ``esm_dict`` (e.g. ``"VESM_3B"`` or ``"VESM3"``).
        local_dir: directory the checkpoint is downloaded to.
        device: target device, as a string (``'cuda'``, ``'cpu'``, ...) or a
            ``torch.device``.

    Returns:
        ``(model, tokenizer)``. For ``"VESM3"`` the tokenizer is the ESM3
        sequence tokenizer; otherwise the matching Hugging Face tokenizer.

    Raises:
        ValueError: if ``model_name`` is not a known VESM model.
    """
    if model_name not in esm_dict:
        # Raise instead of returning None: callers unpack the result, so a
        # None would only surface later as an unrelated TypeError.
        raise ValueError(
            f"Model not found: {model_name!r}; expected one of {list(esm_dict)}"
        )
    ckt = esm_dict[model_name]
    # Normalise strings such as 'cpu' to torch.device; ESM3.from_pretrained
    # appears to read `device.type` (NOTE(review): confirm against esm version).
    device = torch.device(device)

    # download weights; hf_hub_download returns the local path of the file
    weights_path = hf_hub_download(
        repo_id="ntranoslab/vesm", filename=f"{model_name}.pth", local_dir=local_dir
    )

    # load base model
    if model_name == "VESM3":
        from esm.models.esm3 import ESM3
        model = ESM3.from_pretrained(ckt, device=device).to(torch.float)
        tokenizer = model.tokenizers.sequence
    else:
        model = EsmForMaskedLM.from_pretrained(ckt).to(device)
        tokenizer = AutoTokenizer.from_pretrained(ckt)

    # load pretrained VESM weights.
    # map_location lets checkpoints saved on GPU load on CPU-only machines;
    # weights_only restricts unpickling of the downloaded file to tensors/containers.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    # strict=False: parameters missing from the checkpoint keep the base
    # model's values and unexpected keys are ignored.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model, tokenizer
Variant Effect Prediction
# scoring functions
import torch.nn.functional as F
# calculate log-likelihood ratio from the logits
def get_llrs(sequence_logits, input_ids):
    """Convert logits into per-position log-likelihood ratios.

    For each position ``i`` and vocabulary token ``a`` the result holds
    ``log p(a | i) - log p(wt_i | i)``, where ``wt_i = input_ids[i]`` is the
    observed (wild-type) token. A substitution's score can therefore be read
    directly from ``llrs[i, a]``.

    Args:
        sequence_logits: logits whose last dimension indexes the vocabulary.
        input_ids: integer (int64) token ids whose shape equals
            ``sequence_logits.shape[:-1]``.

    Returns:
        A tensor with the same shape as ``sequence_logits``.
    """
    token_log_probs = torch.log_softmax(sequence_logits, dim=-1)
    # Select the wild-type log-prob with gather instead of a one-hot product:
    # the product evaluates -inf * 0 = NaN for any -inf log-prob, which would
    # poison the whole position's sum.
    wt_log_probs = token_log_probs.gather(-1, input_ids.unsqueeze(-1))
    # Broadcasting over the trailing vocab dimension replaces an explicit expand().
    return token_log_probs - wt_log_probs
# compute mutation score
def score_mutation(llrs, mutation, sequence_vocabs, sequence=None):
    """Sum the log-likelihood ratios of one or more substitutions.

    Args:
        llrs: 2-D per-position LLR tensor as produced by ``get_llrs``. Row
            ``i`` is indexed with the 1-based residue position directly, which
            assumes row 0 is a leading special token (as with the tokenizers
            used in this README).
        mutation: substitutions joined by ``":"``, each written as wild-type
            residue, 1-based position, mutant residue (e.g. ``"M1Y:V2T"``).
        sequence_vocabs: mapping from residue letter to vocabulary index.
        sequence: optional wild-type sequence; when given, each substitution's
            wild-type residue is checked against it.

    Returns:
        The summed score as a Python float.

    Raises:
        ValueError: on a malformed substitution, a position outside ``llrs``
            (or ``sequence``), or a wild-type residue that disagrees with
            ``sequence``.
        KeyError: if a mutant residue is missing from ``sequence_vocabs``.
    """
    mutation_score = 0
    n_rows = llrs.shape[0]
    for mut in mutation.split(":"):
        mut = mut.strip()
        if len(mut) < 3 or not mut[1:-1].isdecimal():
            raise ValueError(
                f"Malformed substitution {mut!r} in {mutation!r}; expected e.g. 'M1Y'"
            )
        wt, idx, mt = mut[0], int(mut[1:-1]), mut[-1]
        # Position 0 would score the leading special-token row, and indices past
        # the end would otherwise fail with an opaque IndexError.
        if not 1 <= idx < n_rows:
            raise ValueError(
                f"Position {idx} in {mut!r} is outside the scored range 1..{n_rows - 1}"
            )
        if sequence is not None and (idx > len(sequence) or sequence[idx - 1] != wt):
            found = sequence[idx - 1] if idx <= len(sequence) else None
            raise ValueError(
                f"Wild-type mismatch for {mut!r}: sequence has {found!r} at position {idx}"
            )
        mutation_score += llrs[idx, sequence_vocabs[mt]].item()
    return mutation_score
Sequence-only Models
The sample script below computes mutation scores from the protein sequence alone.
# sequence and mutation
sequence = "MVNSTHRGMHTSLHLWNRSSYRLHSNASESLGKGYSDGGCYEQLFVSPEVFVTLGVISLLENILV"
mutation = "M1Y:V2T"
# Setting
local_dir = 'vesm'
gpu_id = 0
# Always build a torch.device (never a bare 'cpu' string) so `device` has one
# type on every machine; code reading `device.type` (ESM3.from_pretrained
# appears to -- confirm) would otherwise fail on CPU-only hosts.
device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
# Helper
def inference(model, tokenizer, sequence, device):
    """Run a sequence-only model on one protein and return its LLR matrix.

    The sequence is tokenized as a batch of one and passed through ``model``
    without gradient tracking. The logits and token ids of that single batch
    element are then turned into log-likelihood ratios by ``get_llrs``.
    """
    encoded = tokenizer([sequence], return_tensors='pt').to(device)
    with torch.no_grad():
        batch_logits = model(**encoded)['logits']
    # strip the batch dimension from both the logits and the token ids
    position_logits = batch_logits[0]
    token_ids = encoded['input_ids'][0]
    return get_llrs(position_logits, token_ids)
# Sequence-only prediction with VESM
model_name = 'VESM_3B'
model, tokenizer = load_vesm(model_name=model_name, local_dir=local_dir, device=device)
sequence_vocabs = tokenizer.get_vocab()
# per-position log-likelihood ratios, then the summed score of the substitutions
llrs = inference(model=model, tokenizer=tokenizer, sequence=sequence, device=device)
mutation_score = score_mutation(llrs=llrs, mutation=mutation, sequence_vocabs=sequence_vocabs)
print("Predicted score by " + model_name + ": ", mutation_score)
Using Structure with VESM3
from esm.sdk.api import ESMProtein
# Sample structure: a PDB file from the AlphaFold database (fetch the latest version first), e.g.
# !wget https://alphafold.ebi.ac.uk/files/AF-P32245-F1-model_v6.pdb
pdb_file = "AF-P32245-F1-model_v6.pdb"
protein = ESMProtein.from_pdb(pdb_file)
mutation = "M1Y:V2T"

# structure-aware model and its sequence vocabulary
model, tokenizer = load_vesm('VESM3', local_dir=local_dir, device=device)
sequence_vocabs = tokenizer.get_vocab()

# encode sequence + structure, add a batch dimension, and run one forward pass
tokens = model.encode(protein)
seq_tokens = tokens.sequence.reshape(1, -1)
struct_tokens = tokens.structure.reshape(1, -1)
with torch.no_grad():
    outs = model.forward(sequence_tokens=seq_tokens, structure_tokens=struct_tokens)

# LLRs from the (single) batch element's sequence logits, then the mutation score
input_ids = tokens.sequence
logits = outs.sequence_logits[0]
llrs = get_llrs(logits, input_ids)
mutation_score = score_mutation(llrs, mutation, sequence_vocabs)
print("mutation score: ", mutation_score)
License
The source code and model weights for VESM models are distributed under the MIT License. The VESM3 model is a fine-tuned version of ESM3-Open (EvolutionaryScale) and is available under a non-commercial license agreement.