lhallee committed (verified) · Commit 116b160 · 1 parent: bedd3bd

Upload README.md with huggingface_hub

Files changed (1): README.md (+181, −171)
---
library_name: transformers
tags: []
---

# NOTE
The GitHub repository with the implementation and `requirements.txt` can be found [here](https://github.com/Synthyra/FastPLMs.git).

# ESM++
[ESM++](https://github.com/Synthyra/ESMplusplus) is a faithful implementation of [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) ([license](https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement)) that supports batching and standard Hugging Face compatibility without requiring the ESM Python package.
The large version corresponds to the 600 million parameter version of ESMC.

## Attention backend defaults
`sdpa` is the default attention backend for ESM++.

To enable Flex Attention, set `attn_backend="flex"` in the config before loading the model:

```python
from transformers import AutoConfig, AutoModelForMaskedLM

config = AutoConfig.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True)
config.attn_backend = "flex"
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', config=config, trust_remote_code=True)
```

For throughput and memory efficiency, `torch.compile(...)` is highly recommended, especially when using Flex Attention.

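As a minimal sketch of what compilation looks like (a stand-in module is used here so the snippet runs without downloading weights; in practice you would pass the loaded ESM++ model to `torch.compile`):

```python
import torch

# Stand-in module; in practice, compile the loaded ESM++ model instead.
net = torch.nn.Linear(8, 8)
compiled = torch.compile(net)      # compilation happens lazily, on the first forward pass
out = compiled(torch.randn(2, 8))  # first call is slow while kernels compile
```

Subsequent calls with the same shapes reuse the compiled graph, which is where the speedup comes from.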
## Use with 🤗 transformers
```python
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True)
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')

# For MLM training, mask input_ids and set the unmasked positions of labels to -100:
# tokenized['labels'] = tokenized['input_ids'].clone()

output = model(**tokenized)  # pass output_hidden_states=True to also get all hidden states
print(output.logits.shape)             # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
print(output.last_hidden_state.shape)  # last hidden state, (batch_size, seq_len, hidden_size), (2, 11, 1152)
print(output.loss)                     # language modeling loss if you passed labels
# print(output.hidden_states)          # tuple of all hidden states if output_hidden_states=True
```
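The commented-out `labels` line above glosses over the actual masking step. A minimal sketch of one way to build MLM inputs (the helper name `mask_for_mlm`, the token ids, and the masking rate below are illustrative assumptions, not part of ESM++):

```python
import torch

# Illustrative helper (not part of ESM++): corrupt tokens for MLM and set the
# labels of unmasked positions to -100 so the loss ignores them.
def mask_for_mlm(input_ids, mask_token_id, special_ids, mask_prob=0.15):
    labels = input_ids.clone()
    special = torch.zeros_like(input_ids, dtype=torch.bool)
    for sid in special_ids:
        special |= input_ids == sid
    probs = torch.full(input_ids.shape, mask_prob)
    probs[special] = 0.0                  # never mask special tokens
    masked = torch.bernoulli(probs).bool()
    labels[~masked] = -100                # loss is computed only on masked positions
    corrupted = input_ids.clone()
    corrupted[masked] = mask_token_id
    return corrupted, labels
```

With the real tokenizer, use `tokenizer.mask_token_id` and the actual special-token ids; production MLM pipelines often also leave a fraction of masked positions unchanged or replace them with random tokens.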

ESM++ also supports sequence- and token-level classification tasks like ESM2. Simply pass the number of labels during initialization.

```python
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification

model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape)  # (batch_size, num_labels), (2, 2)
```

ESM++ weights are fp32 by default. You can load them in fp16 or bf16 like this:
```python
import torch

model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True, dtype=torch.float16)  # or torch.bfloat16
```
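Equivalently, an already-loaded model can be cast afterwards with standard PyTorch (shown on a stand-in module so the snippet runs without downloading weights):

```python
import torch

layer = torch.nn.Linear(4, 4)     # stand-in for the loaded ESM++ model
layer = layer.to(torch.bfloat16)  # or .half() for fp16
print(layer.weight.dtype)         # torch.bfloat16
```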

## Embed entire datasets with no new code
To embed a list of protein sequences **fast**, just call `embed_dataset`. Sequences are sorted to reduce padding tokens, so the initial progress bar estimate is usually much longer than the actual time it will take.

Example:
```python
embedding_dict = model.embed_dataset(
    sequences=[
        'MALWMRLLPLLALLALWGPDPAAA', ...  # list of protein sequences
    ],
    tokenizer=model.tokenizer,
    batch_size=2,                   # adjust for your GPU memory
    max_len=512,                    # adjust for your needs
    full_embeddings=False,          # if True, no pooling is performed
    embed_dtype=torch.float32,      # cast to the dtype you want
    pooling_types=['mean', 'cls'],  # multiple pooling types are concatenated together
    num_workers=0,                  # if you have many CPU cores, num_workers=4 is fast for large datasets
    sql=False,                      # if True, embeddings are stored in a SQLite database
    sql_db_path='embeddings.db',
    save=True,                      # if True, embeddings are saved as a .pth file
    save_path='embeddings.pth',
)
# embedding_dict maps sequences to embeddings: tensors for .pth, numpy arrays for sql
```

```
model.embed_dataset()
Args:
    sequences: List of protein sequences
    batch_size: Batch size for processing
    max_len: Maximum sequence length
    full_embeddings: Whether to return full residue-wise embeddings (True) or pooled ones (False)
    pooling_types: List of pooling strategies ('mean', 'cls'); multiple types are concatenated
    num_workers: Number of workers for data loading, 0 for the main process
    sql: Whether to store embeddings in a SQLite database - stored in float32
    sql_db_path: Path to SQLite database

Returns:
    Dictionary mapping sequences to embeddings, or None if sql=True

Note:
    - If sql=True, embeddings can only be stored in float32
    - sql is ideal if you need to stream a very large dataset for training in real time
    - save=True is ideal if you can store the entire embedding dictionary in RAM
    - sql takes precedence over save whenever sql=True
    - If your sql database or .pth file already exists, it is scanned first for already embedded sequences
    - Sequences are truncated to max_len and sorted by length in descending order for faster processing
```
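With `save=True`, the resulting `.pth` file is an ordinary torch-serialized dictionary, so retrieving embeddings later is just a `torch.load` away (a toy stand-in dict is used below in place of real `embed_dataset` output):

```python
import torch

# Stand-in for real embed_dataset output: sequence -> pooled embedding tensor.
toy = {'MPRTEIN': torch.randn(1152), 'MSEQWENCE': torch.randn(1152)}
torch.save(toy, 'embeddings.pth')

embedding_dict = torch.load('embeddings.pth', map_location='cpu')
emb = embedding_dict['MPRTEIN']  # look up by sequence string
print(emb.shape)                 # torch.Size([1152])
```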

## Fine-tuning with 🤗 peft
```python
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model

model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True)
# these modules cover the attention layers of both ESM++ and ESM2
target_modules = ["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"]

lora_config = LoraConfig(
    r=8,  # choose LoRA hyperparameters to your liking
    lora_alpha=16,
    lora_dropout=0.01,
    bias="none",
    target_modules=target_modules,
)

# Apply LoRA to the model
model = get_peft_model(model, lora_config)

# Unfreeze the classifier head
for param in model.classifier.parameters():
    param.requires_grad = True
```

For a more thorough fine-tuning example, check out our example script [here](https://github.com/Synthyra/FastPLMs/blob/main/fine_tuning_example.py).

## Returning attention maps
When `attn_backend="flex"`, Flex Attention with a pad-token block mask is used for attention calculations. Optimized attention paths do not return attention maps directly.
ESM++ offers the option `output_attentions=True`, which calculates attention manually. This is much slower, so use it only when you need the attention maps.

```python
output = model(**tokenized, output_attentions=True)
att = output.attentions
len(att)  # 33, one per layer, each of size (batch_size, num_heads, seq_len, seq_len)
```
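A common follow-up step (illustrative post-processing, not an ESM++ API) is to average each layer's map over heads before plotting; sketched here on random stand-in tensors with the shapes described above:

```python
import torch

# Stand-in for output.attentions: one (batch, num_heads, seq_len, seq_len) map per layer.
attentions = [torch.rand(2, 4, 11, 11) for _ in range(3)]
per_layer = [a.mean(dim=1) for a in attentions]  # head-averaged: (batch, seq_len, seq_len)
print(per_layer[0].shape)                        # torch.Size([2, 11, 11])
```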

## Comparison across floating-point precision and implementations
We measured the difference in last hidden states between the fp32 weights and fp16 or bf16. We find that fp16 is closer to the fp32 outputs, so we recommend loading in fp16.
Please note that the ESM package also loads ESMC in fp32 but casts to bf16 by default, which has its share of advantages and disadvantages for inference / training - so load whichever you like for half precision.

Average MSE for FP16: 0.00000003

Average MSE for BF16: 0.00000122

We also measured the difference between the outputs of ESM++ and ESMC (both in bfloat16) on 1000 random sequences to ensure compliance with the ESM package.

Average MSE of last hidden state: 2.46e-09

You can load the weights from the ESM package instead of transformers by replacing `.from_pretrained(...)` with `.from_pretrained_esm('esmc_600m')`.

## Model probes
We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. ESMC (and thus ESM++) performs very well.

The plot below shows performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel) and regression tasks are averaged between Spearman rho and R2.
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/uRAHYQcwkbgajylTIFbUb.png)
165
+
166
+ ## Inference speeds
167
+ We look at various ESM models and their throughput on an H100. Adding efficient batching between ESMC and ESM++ significantly improves the throughput, although ESM++ is also faster than ESMC for batch size one. ESM++ small is even faster than ESM2-35M with long sequences! The most gains will be seen with PyTorch > 2.5 on linux machines.
168
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/Lu6nWB9Fc-7YTql3Z1hVB.png)

### Citation
If you use any of this implementation or work, please cite it (as well as the ESMC preprint).

```
@misc{FastPLMs,
  author = {Hallee, Logan and Bichara, David and Gleghorn, Jason P.},
  title = {FastPLMs: Fast, efficient, protein language model inference from Huggingface AutoModel.},
  year = {2024},
  url = {https://huggingface.co/Synthyra/ESMplusplus_small},
  DOI = {10.57967/hf/3726},
  publisher = {Hugging Face}
}
```