---
library_name: transformers
tags: []
---

# NOTE
The GitHub repository with the implementation and requirements.txt can be found [here](https://github.com/Synthyra/FastPLMs.git).

# FastESM
FastESM is a Hugging Face-compatible, plug-in version of ESM2 rewritten with a newer PyTorch attention implementation.

Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.

The default attention backend is `sdpa`. See the [FastPLMs README](https://github.com/Synthyra/FastPLMs) for a full breakdown of the available backends (`sdpa`, `kernels_flash`, `flex`, `auto`) and how to switch between them. Attention maps (`output_attentions=True`) are supported on all backends via a separate naive computation.
Various other optimizations also make the base implementation slightly different from the one in transformers.
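That separate naive computation is just the textbook attention formula. As a hedged sketch (generic PyTorch, not the actual FastPLMs code), it produces the same outputs as the fused SDPA kernel while also materializing the attention map that fused kernels never build:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 8, 16
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

# Fused kernel: fast and memory-efficient, but never materializes the map
fused_out = F.scaled_dot_product_attention(q, k, v)

# Naive path: builds the full (seq_len, seq_len) map explicitly,
# which is what output_attentions=True can then return
scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
attn_map = scores.softmax(dim=-1)
naive_out = attn_map @ v

print(torch.allclose(fused_out, naive_out, atol=1e-5))  # True
```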

# FastESM2-650

## A faster half-precision version of ESM2-650 with FlashAttention2 and longer context
To enhance the weights with longer context and better fp16 support, we trained ESM2-650 for 50,000 additional steps with a traditional MLM objective (20% masking) in fp16 mixed precision on [OMGprot50](https://huggingface.co/datasets/tattabio/OMG_prot50) at sequence lengths up to **2048**.
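For intuition, a 20% masking step can be sketched in a few lines of plain Python (a hypothetical `mask_for_mlm` helper with made-up token IDs, not the actual training code):

```python
import random

def mask_for_mlm(token_ids, mask_id, mask_rate=0.2, seed=0):
    """Return (masked_ids, labels): ~mask_rate of positions are replaced by
    mask_id and become prediction targets; the rest are ignored (-100)."""
    rng = random.Random(seed)
    n_mask = max(1, int(len(token_ids) * mask_rate))
    positions = set(rng.sample(range(len(token_ids)), n_mask))
    masked = [mask_id if i in positions else t for i, t in enumerate(token_ids)]
    labels = [t if i in positions else -100 for i, t in enumerate(token_ids)]
    return masked, labels

masked, labels = mask_for_mlm(list(range(10)), mask_id=32)
print(sum(t == 32 for t in masked))  # 2 -> 20% of the 10 positions are masked
```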

## Use with 🤗 transformers

### For working with embeddings
```python
import torch
from transformers import AutoModel

model_path = 'Synthyra/FastESM2_650'
model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state

print(embeddings.shape)  # (2, 11, 1280)
```

### For working with sequence logits
```python
import torch
from transformers import AutoModelForMaskedLM

# model_path and tokenized are defined as in the embedding example above
model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
with torch.no_grad():
    logits = model(**tokenized).logits

print(logits.shape)  # (2, 11, 33)
```

### For working with attention maps
```python
import torch
from transformers import AutoModel

# model_path and tokenized are defined as in the embedding example above
model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
with torch.no_grad():
    # tuple of tensors, each (batch_size, num_heads, seq_len, seq_len)
    attentions = model(**tokenized, output_attentions=True).attentions

print(attentions[-1].shape)  # (2, 20, 11, 11)
```

## Embed entire datasets with no new code
To embed a list of protein sequences **fast**, just call `embed_dataset`. Sequences are sorted to reduce padding tokens, so the initial progress bar estimate is usually much longer than the actual time it will take.

Example:
```python
embedding_dict = model.embed_dataset(
    sequences=[
        'MALWMRLLPLLALLALWGPDPAAA', ...  # list of protein sequences
    ],
    tokenizer=model.tokenizer,
    batch_size=2,  # adjust for your GPU memory
    max_len=512,  # adjust for your needs
    full_embeddings=False,  # if True, no pooling is performed
    embed_dtype=torch.float32,  # cast to whichever dtype you want
    pooling_types=['mean', 'cls'],  # more than one pooling type will be concatenated together
    num_workers=0,  # if you have many CPU cores, we find that num_workers=4 is fast for large datasets
    sql=False,  # if True, embeddings will be stored in a SQLite database
    sql_db_path='embeddings.db',
    save=True,  # if True, embeddings will be saved as a .pth file
    save_path='embeddings.pth',
)
# embedding_dict maps sequences to their embeddings: tensors for .pth, numpy arrays for sql
```
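Note that with `pooling_types=['mean', 'cls']` the poolings are concatenated, so pooled vectors have twice the hidden size (2560 for FastESM2-650). A minimal numpy illustration of that behavior (not the library internals):

```python
import numpy as np

def pool(hidden, pooling_types):
    """hidden: (seq_len, hidden_size) array of residue embeddings."""
    parts = []
    for p in pooling_types:
        if p == 'mean':
            parts.append(hidden.mean(axis=0))  # average over residues
        elif p == 'cls':
            parts.append(hidden[0])  # first-token embedding
    return np.concatenate(parts)

hidden = np.random.rand(11, 1280).astype(np.float32)
print(pool(hidden, ['mean']).shape)         # (1280,)
print(pool(hidden, ['mean', 'cls']).shape)  # (2560,)
```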

```
model.embed_dataset()
Args:
    sequences: List of protein sequences
    batch_size: Batch size for processing
    max_len: Maximum sequence length
    full_embeddings: Whether to return full residue-wise embeddings (True) or pooled embeddings (False)
    pooling_types: Types of pooling, e.g. 'mean' or 'cls'
    num_workers: Number of workers for data loading, 0 for the main process
    sql: Whether to store embeddings in a SQLite database - will be stored in float32
    sql_db_path: Path to SQLite database

Returns:
    Dictionary mapping sequences to embeddings, or None if sql=True

Note:
    - If sql=True, embeddings can only be stored in float32
    - sql is ideal if you need to stream a very large dataset for training in real-time
    - save=True is ideal if you can store the entire embedding dictionary in RAM
    - sql takes precedence over save whenever both are set
    - If your sql database or .pth file is already present, it will be scanned first for already-embedded sequences
    - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
```
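The descending length sort mentioned in the notes is a standard padding-reduction trick: batches become length-homogeneous, so far fewer pad tokens are processed. A quick plain-Python sketch of the effect (illustrative lengths, not real data):

```python
def padding_cost(sequences, batch_size):
    """Total pad tokens when each batch is padded to its longest member."""
    total = 0
    for i in range(0, len(sequences), batch_size):
        batch = sequences[i:i + batch_size]
        longest = max(len(s) for s in batch)
        total += sum(longest - len(s) for s in batch)
    return total

# Short and long sequences interleaved vs. sorted by length (descending)
seqs = ['A' * n for n in (5, 500, 8, 480, 12, 510)]
unsorted_cost = padding_cost(seqs, batch_size=2)
sorted_cost = padding_cost(sorted(seqs, key=len, reverse=True), batch_size=2)
print(unsorted_cost, sorted_cost)  # 1465 481
```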

## Model probes
We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. FastESM performs very well.

The plot below shows performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel), and regression task scores are averaged between Spearman rho and R2.
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/d1Xi6k1Q4-9By_MtzTvdV.png)
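Concretely, the normalization is a min-max rescaling between the two anchors; with hypothetical numbers:

```python
def normalize(score, negative_control, best):
    """0 = random-vector baseline, 1 = best performer on the task."""
    return (score - negative_control) / (best - negative_control)

# Hypothetical task scores: random baseline 0.50, best model 0.90
print(round(normalize(0.80, negative_control=0.50, best=0.90), 3))  # 0.75
```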

## Comparison of half precisions
Presumably because we trained in fp16 mixed precision, fp16 inference produces outputs closer to the fp32 weights than bf16 does. Therefore, we recommend loading in fp16.

When summing the MSE over 1000 sequences vs. the fp32 weights:

Average MSE for FP16: 0.00000140

Average MSE for BF16: 0.00004125
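This gap is expected from mantissa width: fp16 keeps 10 mantissa bits to bf16's 7, so for in-range activations fp16 rounds more finely. A numpy sketch reproducing the direction of the effect (bf16 simulated by truncating float32 mantissas, since numpy has no native bfloat16):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=100_000).astype(np.float32)

# fp16 round-trip error
mse_fp16 = np.mean((x - x.astype(np.float16).astype(np.float32)) ** 2)

# bf16 simulated by zeroing the low 16 bits of the float32 representation
bf16 = (x.view(np.uint32) & np.uint32(0xFFFF0000)).view(np.float32)
mse_bf16 = np.mean((x - bf16) ** 2)

print(mse_fp16 < mse_bf16)  # True: fp16 stays closer to the fp32 values
```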

### Inference speed
We compare various ESM models and their throughput on an H100. FastESM is over twice as fast as ESM2-650 on longer sequences. PyTorch 2.5+ is required for the largest savings; see [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/PvaBGfuJXEW2v_WLkt63y.png)

### Citations

```bibtex
@misc{FastPLMs,
  author = {Hallee, Logan and Bichara, David and Gleghorn, Jason P.},
  title = {FastPLMs: Fast, efficient, protein language model inference from Huggingface AutoModel.},
  year = {2024},
  url = {https://huggingface.co/Synthyra/ESMplusplus_small},
  doi = {10.57967/hf/3726},
  publisher = {Hugging Face}
}
```

```bibtex
@article{lin2023esm2,
  title = {Evolutionary-scale prediction of atomic-level protein structure with a language model},
  author = {Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Zhongkai and Lu, Wenting and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori and Shmueli, Yaniv and dos Santos Costa, Allan and Fazel-Zarandi, Maryam and Sercu, Tom and Candido, Salvatore and Rives, Alexander},
  journal = {Science},
  volume = {379},
  number = {6637},
  pages = {1123--1130},
  year = {2023},
  doi = {10.1126/science.ade2574}
}
```

```bibtex
@article{dong2024flexattention,
  title = {Flex Attention: A Programming Model for Generating Optimized Attention Kernels},
  author = {Dong, Juechu and Feng, Boyuan and Guessous, Driss and Liang, Yanbo and He, Horace},
  journal = {arXiv preprint arXiv:2412.05496},
  year = {2024}
}
```

```bibtex
@inproceedings{paszke2019pytorch,
  title = {PyTorch: An Imperative Style, High-Performance Deep Learning Library},
  author = {Paszke, Adam and Gross, Sam and Massa, Francisco and Lerer, Adam and Bradbury, James and Chanan, Gregory and Killeen, Trevor and Lin, Zeming and Gimelshein, Natalia and Antiga, Luca and Desmaison, Alban and K{\"o}pf, Andreas and Yang, Edward and DeVito, Zach and Raison, Martin and Tejani, Alykhan and Chilamkurthy, Sasank and Steiner, Benoit and Fang, Lu and Bai, Junjie and Chintala, Soumith},
  booktitle = {Advances in Neural Information Processing Systems 32},
  year = {2019}
}
```