---
language:
- en
---

# VESM: Co-distillation of ESM models for Variant Effect Prediction

This repository contains the VESM protein language models developed in the paper ["VESM: Compressing the collective knowledge of ESM into a single protein language model"](vesm_arxiv) by Tuan Dinh, Seon-Kyeong Jang, Noah Zaitlen, and Vasilis Ntranos.

---

## Quick start

A simple way to get started is to run our notebook directly on a Google Colab instance:

[![Getting Started with VESM](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ntranoslab/vesm/blob/main/notebooks/VESM_Getting_Started.ipynb)

See also https://github.com/ntranoslab/vesm.

## Download models

**Using python**

```py
from huggingface_hub import snapshot_download, hf_hub_download

local_dir = './vesm'

# Download a single model: pick an index into the list of released checkpoints
model_names = ["VESM_35M", "VESM_150M", "VESM_650M", "VESM_3B", "VESM3"]
model_name = model_names[0]
hf_hub_download(repo_id="ntranoslab/vesm", filename=f"{model_name}.pth", local_dir=local_dir)

# Download all models
snapshot_download(repo_id="ntranoslab/vesm", local_dir=local_dir)
```

**Using huggingface CLI**

```bash
huggingface-cli download ntranoslab/vesm --local-dir ./vesm
```

---

## Usage

Below, we show a simple example of using our models for variant effect prediction.

**Loading helpers**

```py
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, EsmForMaskedLM

# Base checkpoints that each VESM model is fine-tuned from
esm_dict = {
    "VESM_35M": 'facebook/esm2_t12_35M_UR50D',
    "VESM_150M": 'facebook/esm2_t30_150M_UR50D',
    "VESM_650M": 'facebook/esm2_t33_650M_UR50D',
    "VESM_3B": 'facebook/esm2_t36_3B_UR50D',
    "VESM3": "esm3_sm_open_v1"
}

def load_vesm(model_name="VESM_3B", local_dir="vesm", device='cuda'):
    if model_name not in esm_dict:
        raise ValueError(f"Model not found: {model_name}")
    ckt = esm_dict[model_name]

    # download weights
    hf_hub_download(repo_id="ntranoslab/vesm", filename=f"{model_name}.pth", local_dir=local_dir)

    # load base model
    if model_name == "VESM3":
        from esm.models.esm3 import ESM3
        model = ESM3.from_pretrained(ckt, device=device).to(torch.float)
        tokenizer = model.tokenizers.sequence
    else:
        model = EsmForMaskedLM.from_pretrained(ckt).to(device)
        tokenizer = AutoTokenizer.from_pretrained(ckt)

    # load the pretrained VESM weights on top of the base model
    model.load_state_dict(torch.load(f'{local_dir}/{model_name}.pth', map_location=device), strict=False)
    return model, tokenizer
```

**Variant Effect Prediction**

```py
# scoring functions
import torch.nn.functional as F

# calculate the log-likelihood ratio (LLR) of every token at every position,
# relative to the wild-type token
def get_llrs(sequence_logits, input_ids):
    token_probs = torch.log_softmax(sequence_logits, dim=-1)
    # keep only the wild-type token's log-probability at each position
    wt_positions = F.one_hot(input_ids, num_classes=token_probs.shape[-1])
    wt_probs = token_probs * wt_positions
    wt_probs = wt_probs.sum(dim=-1, keepdim=True)
    # LLR = log p(mutant) - log p(wild-type), broadcast across the vocabulary
    llrs = token_probs - wt_probs.expand(token_probs.shape)
    return llrs

# compute the score of a (multi-)mutation, e.g. "M1Y:V2T"
def score_mutation(llrs, mutation, sequence_vocabs):
    mutation_score = 0
    for mut in mutation.split(":"):
        _, idx, mt = mut[0], int(mut[1:-1]), mut[-1]
        pred = llrs[idx, sequence_vocabs[mt]]
        mutation_score += pred.item()
    return mutation_score
```

#### Sequence-only Models

Here, we provide sample scripts to compute mutation scores.
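Mutations are written in the standard `<wt><pos><alt>` notation with 1-based positions, and multiple substitutions are joined with `:`. A position indexes `llrs` directly because the tokenizer prepends a BOS/CLS token, shifting residue 1 to row 1. As a minimal illustration (this loop mirrors the parsing inside `score_mutation` above):

```py
# Illustrative only: how a mutation string decomposes
for mut in "M1Y:V2T".split(":"):
    wt, pos, alt = mut[0], int(mut[1:-1]), mut[-1]
    print(f"wild-type {wt} at position {pos} -> mutant {alt}")
```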
```py
# sequence and mutation
sequence = "MVNSTHRGMHTSLHLWNRSSYRLHSNASESLGKGYSDGGCYEQLFVSPEVFVTLGVISLLENILV"
mutation = "M1Y:V2T"
```

```py
# Setting
local_dir = 'vesm'
gpu_id = 0
device = torch.device(f'cuda:{gpu_id}') if torch.cuda.is_available() else 'cpu'

# Helper
def inference(model, tokenizer, sequence, device):
    tokens = tokenizer([sequence], return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model(**tokens)
    logits = outputs['logits'][0]
    input_ids = tokens['input_ids'][0]
    # calculate log-likelihood ratios from the logits
    llrs = get_llrs(logits, input_ids)
    return llrs

# Prediction with VESM
model_name = 'VESM_3B'
model, tokenizer = load_vesm(model_name, local_dir=local_dir, device=device)
sequence_vocabs = tokenizer.get_vocab()

# compute mutation score
llrs = inference(model, tokenizer, sequence, device)
mutation_score = score_mutation(llrs, mutation, sequence_vocabs)
print(f"Predicted score by {model_name}: ", mutation_score)
```

#### Using Structure with VESM3

```py
from esm.sdk.api import ESMProtein

# A sample AlphaFold structure (PDB file): download the latest version
# !wget https://alphafold.ebi.ac.uk/files/AF-P32245-F1-model_v6.pdb
pdb_file = "AF-P32245-F1-model_v6.pdb"
protein = ESMProtein.from_pdb(pdb_file)
mutation = "M1Y:V2T"
```

```py
# load model
model, tokenizer = load_vesm('VESM3', local_dir=local_dir, device=device)
sequence_vocabs = tokenizer.get_vocab()

# inference on both sequence and structure tokens
tokens = model.encode(protein)
seq_tokens = tokens.sequence.reshape(1, -1)
struct_tokens = tokens.structure.reshape(1, -1)
with torch.no_grad():
    outs = model.forward(sequence_tokens=seq_tokens, structure_tokens=struct_tokens)
logits = outs.sequence_logits[0, :, :]
input_ids = tokens.sequence

# calculate log-likelihood ratios from the logits
llrs = get_llrs(logits, input_ids)

# compute mutation score
mutation_score = score_mutation(llrs, mutation, sequence_vocabs)
print("mutation score: ", mutation_score)
```

## License

The source code and model weights for the VESM models are distributed under the MIT License. The VESM3 model is a fine-tuned version of ESM3-Open (EvolutionaryScale) and is available under a [non-commercial license agreement](https://www.evolutionaryscale.ai/policies/cambrian-open-license-agreement).
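---

**Scoring every single substitution**

Beyond scoring individual variants, the per-position LLRs make it cheap to score a protein's full single-substitution landscape. Below is a minimal sketch, assuming `llrs`, `sequence`, and `sequence_vocabs` were computed as in the sequence-only example above; `score_all_substitutions` is a hypothetical helper, not part of this repository:

```py
# Score all 19 possible substitutions at every position of `sequence`
# (hypothetical helper; assumes `llrs` and `sequence_vocabs` from the example above)
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

def score_all_substitutions(llrs, sequence, sequence_vocabs):
    scores = {}
    for pos, wt in enumerate(sequence, start=1):  # 1-based positions, matching the mutation notation
        for alt in AMINO_ACIDS:
            if alt != wt:
                scores[f"{wt}{pos}{alt}"] = llrs[pos, sequence_vocabs[alt]].item()
    return scores

all_scores = score_all_substitutions(llrs, sequence, sequence_vocabs)
print(min(all_scores, key=all_scores.get))  # lowest LLR = most deleterious predicted substitution
```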