---
language:
- en
---

# VESM: Co-distillation of ESM models for Variant Effect Prediction

This repository contains the VESM protein language models developed in the paper ["VESM: Compressing the collective knowledge of ESM into a single protein language model"](vesm_arxiv) by Tuan Dinh, Seon-Kyeong Jang, Noah Zaitlen and Vasilis Ntranos.

---
## Quick start <a name="quickstart"></a>
A simple way to get started is to run our notebook directly on a Google Colab instance:
[![Getting Started with VESM](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ntranoslab/vesm/blob/main/notebooks/VESM_Getting_Started.ipynb) 


For the full source code, see the GitHub repository: https://github.com/ntranoslab/vesm

## Download models

**Using Python**
```py
from huggingface_hub import snapshot_download, hf_hub_download

local_dir = './vesm'

# Download a single model: model_offset selects one of the available checkpoints
model_offset = 0
available_models = ["VESM_35M", "VESM_150M", "VESM_650M", "VESM_3B", "VESM3"]
model_name = available_models[model_offset]
hf_hub_download(repo_id="ntranoslab/vesm", filename=f"{model_name}.pth", local_dir=local_dir)

# Download every model in the repository at once
snapshot_download(repo_id="ntranoslab/vesm", local_dir=local_dir)
```

**Using the Hugging Face CLI**
```bash
# download all models into ./vesm (the same directory used by the Python examples)
huggingface-cli download ntranoslab/vesm --local-dir ./vesm
```

---
## Usage  <a name="usage"></a>

Below is a simple example of how to use our models to predict variant effects.

**Loading helpers**
```py
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, EsmForMaskedLM

# VESM model name -> base ESM checkpoint the VESM weights are loaded into
esm_dict = {
    "VESM_35M": 'facebook/esm2_t12_35M_UR50D',
    "VESM_150M": 'facebook/esm2_t30_150M_UR50D',
    "VESM_650M": 'facebook/esm2_t33_650M_UR50D',
    "VESM_3B": 'facebook/esm2_t36_3B_UR50D',
    "VESM3": "esm3_sm_open_v1"
}


def load_vesm(model_name="VESM_3B", local_dir="vesm", device='cuda'):
    """Load a VESM model together with its tokenizer.

    The base checkpoint listed in ``esm_dict`` is loaded first, then the
    ``<model_name>.pth`` weights downloaded from the ``ntranoslab/vesm``
    Hub repository are loaded on top of it (non-strictly).

    Args:
        model_name: One of the keys of ``esm_dict``.
        local_dir: Directory the ``.pth`` weight file is downloaded into.
        device: Device the model is placed on; the checkpoint tensors are
            mapped to it as well.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not a key of ``esm_dict``.
    """
    if model_name not in esm_dict:
        # Raise instead of returning None: callers unpack the result as
        # ``model, tokenizer = load_vesm(...)``, which would otherwise fail
        # with an unrelated TypeError.
        raise ValueError(
            f"Model not found: {model_name!r}; expected one of {list(esm_dict)}"
        )
    ckt = esm_dict[model_name]

    # download weights; hf_hub_download returns the local path of the file
    weights_path = hf_hub_download(
        repo_id="ntranoslab/vesm", filename=f"{model_name}.pth", local_dir=local_dir
    )
    # load base model
    if model_name == "VESM3":
        from esm.models.esm3 import ESM3
        model = ESM3.from_pretrained(ckt, device=device).to(torch.float)
        tokenizer = model.tokenizers.sequence
    else:
        model = EsmForMaskedLM.from_pretrained(ckt).to(device)
        tokenizer = AutoTokenizer.from_pretrained(ckt)
    # load pretrained VESM weights; map_location lets a checkpoint saved on GPU
    # be loaded on a CPU-only machine (or onto a different GPU)
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    return model, tokenizer
```

**Variant Effect Prediction**

```py
# scoring functions
import torch.nn.functional as F
# calculate log-likelihood ratio from the logits
def get_llrs(sequence_logits, input_ids):
    """Return log-likelihood ratios of every vocabulary token vs. the input token.

    At each position, the log-probability of the input (wild-type) token is
    subtracted from the log-probabilities of all vocabulary tokens. The input
    token itself therefore scores 0, and a substitution to token ``t`` scores
    ``log p(t) - log p(wt)``.

    Args:
        sequence_logits: Logits whose last dimension indexes the vocabulary,
            e.g. ``(seq_len, vocab_size)``.
        input_ids: Integer token ids shaped like ``sequence_logits`` without
            its last dimension, e.g. ``(seq_len,)``.

    Returns:
        A tensor with the same shape as ``sequence_logits``.
    """
    log_probs = torch.log_softmax(sequence_logits, dim=-1)
    # Select the wild-type log-probability with gather rather than a one-hot
    # product: no one-hot tensor is materialized, and -inf log-probabilities
    # are not turned into NaN (-inf * 0 == NaN).
    wt_log_probs = log_probs.gather(-1, input_ids.unsqueeze(-1))
    # shape (..., 1) broadcasts against (..., vocab_size)
    return log_probs - wt_log_probs

# compute mutation score
def score_mutation(llrs, mutation, sequence_vocabs):
    """Return the summed LLR score of a (possibly multi-site) mutation.

    ``mutation`` holds one or more substitutions separated by ``":"``, each
    written as ``<wt><position><mt>`` (e.g. ``"M1Y:V2T"``). For every
    substitution, ``llrs[position, sequence_vocabs[mt]]`` is read, and the
    values are added together. The wild-type letter is parsed but not
    checked against the sequence.

    ``position`` is used directly as a row index into ``llrs``. With 1-based
    positions, this presumably relies on a special token occupying row 0 of
    the tokenized sequence (TODO confirm for each tokenizer).
    """
    def _substitution_llr(substitution):
        _wt, position, alt = substitution[0], int(substitution[1:-1]), substitution[-1]
        return llrs[position, sequence_vocabs[alt]].item()

    return sum(_substitution_llr(sub) for sub in mutation.split(":"))
```

#### Sequence-only Models 

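Here, we provide sample scripts to compute mutation scores. A mutation is written as `<wild-type residue><position><mutant residue>` with 1-based positions (e.g. `M1Y`). Multiple substitutions are joined with `:` (e.g. `M1Y:V2T`), and their scores are summed.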
```py
# sequence and mutation
# mutation: ':'-separated substitutions, each <wt residue><1-based position><mutant residue>
sequence = "MVNSTHRGMHTSLHLWNRSSYRLHSNASESLGKGYSDGGCYEQLFVSPEVFVTLGVISLLENILV"
mutation = "M1Y:V2T"
```

```py
# Setting
local_dir = 'vesm'
gpu_id = 0
# Build a torch.device in both branches so `device` always has the same type
# (previously it was a torch.device on GPU but the plain string 'cpu' otherwise).
device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')

# Helper
def inference(model, tokenizer, sequence, device):
    """Return the per-position log-likelihood-ratio matrix of ``sequence``.

    The sequence is tokenized as a batch of one and moved to ``device``. It is
    then run through ``model`` without gradient tracking. The logits and
    token ids of that single entry are converted with ``get_llrs``.
    """
    batch = tokenizer([sequence], return_tensors='pt').to(device)
    with torch.no_grad():
        first_logits = model(**batch)['logits'][0]
    return get_llrs(first_logits, batch['input_ids'][0])

# Prediction with VESM
model_name = 'VESM_3B'
model, tokenizer = load_vesm(model_name=model_name, local_dir=local_dir, device=device)
# per-position log-likelihood ratios for the whole sequence
llrs = inference(model, tokenizer, sequence=sequence, device=device)
# residue letter -> token id, used to look up the mutant columns
sequence_vocabs = tokenizer.get_vocab()
mutation_score = score_mutation(llrs, mutation=mutation, sequence_vocabs=sequence_vocabs)
print(f"Predicted score by {model_name}: ", mutation_score)
```


#### Using Structure with VESM3
```py
from esm.sdk.api import ESMProtein

# Sample structure: AlphaFold model AF-P32245-F1. AlphaFold DB files are
# versioned, so fetch the latest available version first, e.g. in a notebook:
# !wget https://alphafold.ebi.ac.uk/files/AF-P32245-F1-model_v6.pdb
pdb_file = "AF-P32245-F1-model_v6.pdb"
protein = ESMProtein.from_pdb(pdb_file)  # structure-aware protein object
mutation = "M1Y:V2T"  # same notation as in the sequence-only example
```

```py
# load model
model, tokenizer = load_vesm('VESM3', local_dir=local_dir, device=device)
sequence_vocabs = tokenizer.get_vocab()

# inference: encode the protein into sequence and structure tokens
tokens = model.encode(protein)
seq_tokens = tokens.sequence.reshape(1, -1)       # add a batch dimension
struct_tokens = tokens.structure.reshape(1, -1)
with torch.no_grad():
    # call the module instead of .forward() so nn.Module hooks run as usual
    outs = model(sequence_tokens=seq_tokens, structure_tokens=struct_tokens)
    logits = outs.sequence_logits[0, :, :]
    input_ids = tokens.sequence

# calculate log-likelihood ratio from the logits
llrs = get_llrs(logits, input_ids)
# compute mutation score
mutation_score = score_mutation(llrs, mutation, sequence_vocabs)
print("mutation score: ", mutation_score)
```


## License  <a name="license"></a>

The source code and model weights for VESM models are distributed under the MIT License. 
The VESM3 model is a fine-tuned version of ESM3-Open (EvolutionaryScale) and is available under a [non-commercial license agreement](https://www.evolutionaryscale.ai/policies/cambrian-open-license-agreement).