lhallee committed · verified
Commit 3671b3e · 1 Parent(s): d8890e1

Update README.md

Files changed (1): README.md (+119, -111)

README.md:
---
library_name: transformers
tags: []
---

# FastESM
FastESM is a Hugging Face-compatible, plug-in version of ESM2 rewritten with a newer PyTorch attention implementation.

Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.

Outputting attention maps (or using the contact prediction head) is not natively possible with SDPA. You can still pass `output_attentions=True` to have attention calculated manually and returned.
Various other optimizations also make the base implementation slightly different from the one in transformers.

## Use with 🤗 transformers

### Supported models
```python
model_dict = {
    # Synthyra/ESM2-8M
    'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
    # Synthyra/ESM2-35M
    'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
    # Synthyra/ESM2-150M
    'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
    # Synthyra/ESM2-650M
    'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
    # Synthyra/ESM2-3B
    'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
}
```
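
The dictionary maps each Synthyra model name (shown in the comments) to the original facebook checkpoint it mirrors. As a sketch, any of these should load the same way as the 8M example below; the specific checkpoint used here is only illustrative:

```python
import torch
from transformers import AutoModel

# Illustrative: the 150M checkpoint, loaded exactly like the 8M example further down.
model = AutoModel.from_pretrained(
    'Synthyra/ESM2-150M',
    torch_dtype=torch.float16,
    trust_remote_code=True,  # pulls in the FastESM implementation
).eval()
```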

### For working with embeddings
```python
import torch
from transformers import AutoModel

model_path = 'Synthyra/ESM2-8M'
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state

print(embeddings.shape) # (2, 11, 320) for ESM2-8M; the hidden size depends on the checkpoint (1280 for the 650M model)
```
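
If you need one vector per protein instead of per-residue embeddings, a simple mean pool over non-padding tokens works on the outputs above (a minimal sketch; the `embed_dataset` helper further down does this for you with `pooling_type='mean'`):

```python
# Mean-pool last_hidden_state over real (non-padding) tokens.
# Assumes `tokenized` and `embeddings` from the block above.
mask = tokenized['attention_mask'].unsqueeze(-1).to(embeddings.dtype)  # (2, 11, 1)
pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1)              # (2, hidden_size)
print(pooled.shape)
```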

### For working with sequence logits
```python
import torch
from transformers import AutoModelForMaskedLM

# Reuses model_path and tokenized from the embeddings example above
model = AutoModelForMaskedLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
with torch.no_grad():
    logits = model(**tokenized).logits

print(logits.shape) # (2, 11, 33)
```
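
The logits can be used to score masked positions. The sketch below masks a single residue by id and reads off the most likely replacement; it assumes the tokenizer exposes the standard ESM `mask_token_id`:

```python
# Mask the residue at position 3 of the first sequence and predict it (illustrative sketch).
masked_ids = tokenized['input_ids'].clone()
masked_ids[0, 3] = tokenizer.mask_token_id
with torch.no_grad():
    masked_logits = model(input_ids=masked_ids, attention_mask=tokenized['attention_mask']).logits

pred_id = masked_logits[0, 3].argmax(-1).item()
print(tokenizer.convert_ids_to_tokens([pred_id]))  # most likely amino acid at the masked position
```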

### For working with attention maps
```python
import torch
from transformers import AutoModel

# Reuses model_path and tokenized from the embeddings example above
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
with torch.no_grad():
    attentions = model(**tokenized, output_attentions=True).attentions # tuple of tensors, each (batch_size, num_heads, seq_len, seq_len)

print(attentions[-1].shape) # (2, 20, 11, 11)
```
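
The tuple holds one tensor per layer; stacking it gives a single array that is convenient for plotting or head analysis (a minimal sketch using the variables above):

```python
# Stack per-layer attentions into (num_layers, batch_size, num_heads, seq_len, seq_len).
stacked = torch.stack(attentions, dim=0)
print(stacked.shape)  # (6, 2, 20, 11, 11) for the 6-layer ESM2-8M checkpoint
```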

### Contact prediction
Because attentions can be computed with the naive attention implementation, contact prediction is also supported.
```python
with torch.no_grad():
    contact_map = model.predict_contacts(**tokenized).squeeze().cpu().numpy() # (seq_len, seq_len)
```
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/9707OSXZ3Wdgn0Ni-55T-.png)
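
The resulting map can be visualized with any plotting library; for example, with matplotlib (not a dependency of the model, and assuming `contact_map` is a 2-D array as in the comment above):

```python
import matplotlib.pyplot as plt

plt.imshow(contact_map, cmap='viridis')    # contact probability per residue pair
plt.xlabel('Residue index')
plt.ylabel('Residue index')
plt.colorbar(label='Contact probability')
plt.savefig('contact_map.png', dpi=200)
```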

## Embed entire datasets with no new code
To embed a list of protein sequences **fast**, just call `embed_dataset`. Sequences are sorted by length to reduce padding tokens, so the initial progress bar estimate is usually much longer than the actual runtime.
```python
embeddings = model.embed_dataset(
    sequences=sequences,          # list of protein strings
    batch_size=16,                # embedding batch size
    max_len=2048,                 # truncate to max_len
    full_embeddings=True,         # return residue-wise embeddings
    full_precision=False,         # store at half precision (set True to keep float32)
    pooling_type='mean',          # use mean pooling if protein-wise embeddings
    num_workers=0,                # data loading num workers
    sql=False,                    # return dictionary of sequences and embeddings
)

_ = model.embed_dataset(
    sequences=sequences,          # list of protein strings
    batch_size=16,                # embedding batch size
    max_len=2048,                 # truncate to max_len
    full_embeddings=True,         # return residue-wise embeddings
    full_precision=False,         # store at half precision (set True to keep float32)
    pooling_type='mean',          # use mean pooling if protein-wise embeddings
    num_workers=0,                # data loading num workers
    sql=True,                     # store sequences in local SQL database
    sql_db_path='embeddings.db',  # path to .db file of choice
)
```
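
With `sql=False` the call returns, per the comment above, a dictionary of sequences and embeddings; assuming it is keyed by the input sequence string (an assumption, not documented here), lookup would look like:

```python
# Assumes the returned dict maps each input sequence string to its embedding tensor.
emb = embeddings['MPRTEIN']
print(emb.shape)  # residue-wise embedding when full_embeddings=True
```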

### Citation
If you use any of this implementation or work, please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
```
@misc{FastESM2,
    author = { Hallee, L. and Bichara, D. and Gleghorn, J. P. },
    title = { FastESM2 },
    year = 2024,
    url = { https://huggingface.co/Synthyra/FastESM2_650 },
    doi = { 10.57967/hf/3729 },
    publisher = { Hugging Face }
}
```