Commit 59e62da by lhallee (verified) · Parent: a9cfa36

Upload README.md with huggingface_hub

Files changed (1): README.md
---
library_name: transformers
tags: []
---

# NOTE
The GitHub repository with the implementation and `requirements.txt` can be found [here](https://github.com/Synthyra/FastPLMs.git).

# FastESM
FastESM is a Hugging Face-compatible, plug-in version of ESM2 rewritten with newer PyTorch attention implementations.

Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.

## Attention backends

`sdpa` (PyTorch scaled dot-product attention) is the default. It is fast, memory-efficient, and numerically equivalent to naive attention. The backend is set via `config.attn_backend` before loading.

| Backend | Key | Notes |
| :--- | :--- | :--- |
| PyTorch SDPA | `"sdpa"` | Default. Exact numerics, stable on all hardware. |
| Flash Attention | `"kernels_flash"` | Fastest. Requires `pip install kernels` (pre-built; no hours-long compilation). Outputs are not bitwise identical to SDPA due to online softmax reordering; differences are often small but not guaranteed to be inconsequential. Use `"sdpa"` if exact numerics matter. |
| Flex Attention | `"flex"` | Skips padding tokens via a block mask, so it is faster on variable-length batches. Near-exact numerics. First use compiles a Triton kernel (30–120 s). |
| Auto | `"auto"` | Picks the best available: `kernels_flash` → `flex` → `sdpa`. |

```python
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("Synthyra/ESM2-150M", trust_remote_code=True)
config.attn_backend = "flex"  # or "kernels_flash", "sdpa", "auto"
model = AutoModel.from_pretrained("Synthyra/ESM2-150M", config=config, trust_remote_code=True)
```

`torch.compile(model)` is highly recommended for sustained throughput, especially with Flex Attention.

Attention maps (`output_attentions=True`) are supported with all backends. For SDPA, Flash, and Flex, the attention weights are computed via a separate naive pass, so there is no memory benefit to enabling it during normal inference.
Various other optimizations also make the base implementation slightly different from the one in transformers.

## Use with 🤗 transformers

### Supported models
```python
model_dict = {
    # Synthyra/ESM2-8M
    'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
    # Synthyra/ESM2-35M
    'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
    # Synthyra/ESM2-150M
    'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
    # Synthyra/ESM2-650M
    'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
    # Synthyra/ESM2-3B
    'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
}
```

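Any of these checkpoints can be referred to either by its short Synthyra alias or by the original facebook checkpoint id. The `resolve_repo` helper below is illustrative, not part of the package:

```python
model_dict = {
    'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
    'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
    'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
    'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
    'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
}

def resolve_repo(name: str) -> str:
    """Map a short alias or a facebook checkpoint id to a Synthyra repo id."""
    if name in model_dict:  # short alias, e.g. 'ESM2-8M'
        return f'Synthyra/{name}'
    for alias, checkpoint in model_dict.items():  # original id, e.g. 'facebook/esm2_t6_8M_UR50D'
        if checkpoint == name:
            return f'Synthyra/{alias}'
    raise KeyError(f'Unknown model: {name}')

print(resolve_repo('facebook/esm2_t12_35M_UR50D'))  # Synthyra/ESM2-35M
```

The returned id can be passed straight to `AutoModel.from_pretrained(..., trust_remote_code=True)`.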
### For working with embeddings
```python
import torch
from transformers import AutoModel

model_path = 'Synthyra/ESM2-8M'
model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state

print(embeddings.shape)  # (2, 11, 320) for ESM2-8M
```
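`embed_dataset` (below) handles pooling for you, but if you work with `last_hidden_state` directly, padding positions should be excluded when averaging. A minimal sketch with stand-in tensors (shapes assume ESM2-8M's 320-dim hidden states; in practice `hidden` comes from the model and `mask` from `tokenized['attention_mask']`):

```python
import torch

def masked_mean_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average residue embeddings over the sequence, ignoring padding positions."""
    mask = mask.unsqueeze(-1).to(hidden.dtype)  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)         # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1)       # avoid division by zero
    return summed / counts

# stand-ins for last_hidden_state and attention_mask
hidden = torch.randn(2, 11, 320)
mask = torch.tensor([[1] * 9 + [0, 0], [1] * 11])
pooled = masked_mean_pool(hidden, mask)
print(pooled.shape)  # torch.Size([2, 320])
```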

### For working with sequence logits
```python
import torch
from transformers import AutoModelForMaskedLM

# reuses model_path and tokenized from the embeddings example above
model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
with torch.no_grad():
    logits = model(**tokenized).logits

print(logits.shape)  # (2, 11, 33)
```

### For working with attention maps
```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
with torch.no_grad():
    # tuple of tensors, one per layer, each (batch_size, num_heads, seq_len, seq_len)
    attentions = model(**tokenized, output_attentions=True).attentions

print(attentions[-1].shape)  # (2, 20, 11, 11)
```

### Contact prediction
Because attentions can be output via the naive attention implementation, contact prediction is also supported:
```python
with torch.no_grad():
    contact_map = model.predict_contacts(**tokenized).squeeze().cpu().numpy()  # (seq_len, seq_len)
```
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/9707OSXZ3Wdgn0Ni-55T-.png)

## Embed entire datasets with no new code
To embed a list of protein sequences **fast**, just call `embed_dataset`. Sequences are sorted by length to reduce padding tokens, so the initial progress-bar estimate is usually much longer than the actual run time.

Example:
```python
embedding_dict = model.embed_dataset(
    sequences=[
        'MALWMRLLPLLALLALWGPDPAAA', ...  # list of protein sequences
    ],
    tokenizer=model.tokenizer,
    batch_size=2,  # adjust for your GPU memory
    max_len=512,  # adjust for your needs
    full_embeddings=False,  # if True, no pooling is performed
    embed_dtype=torch.float32,  # cast to the dtype you want
    pooling_types=['mean', 'cls'],  # multiple pooling types are concatenated together
    num_workers=0,  # if you have many CPU cores, num_workers=4 is fast for large datasets
    sql=False,  # if True, embeddings are stored in a SQLite database
    sql_db_path='embeddings.db',
    save=True,  # if True, embeddings are saved as a .pth file
    save_path='embeddings.pth',
)
# embedding_dict maps sequences to embeddings: tensors for .pth, numpy arrays for sql
```
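Why length sorting helps: batches are padded to their longest member, so grouping similar-length sequences wastes far fewer tokens. A small stdlib-only sketch (the batch size and lengths are arbitrary, for illustration only):

```python
def padding_tokens(sequences, batch_size):
    """Count pad tokens needed when each batch is padded to its longest member."""
    total = 0
    for i in range(0, len(sequences), batch_size):
        batch = sequences[i:i + batch_size]
        longest = max(len(s) for s in batch)
        total += sum(longest - len(s) for s in batch)
    return total

seqs = ['M' * n for n in (5, 300, 12, 280, 8, 310)]
unsorted_pad = padding_tokens(seqs, batch_size=2)
sorted_pad = padding_tokens(sorted(seqs, key=len, reverse=True), batch_size=2)
print(unsorted_pad, sorted_pad)  # sorting greatly reduces wasted padding
```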

```
model.embed_dataset()
Args:
    sequences: List of protein sequences
    tokenizer: Tokenizer to use (e.g. model.tokenizer)
    batch_size: Batch size for processing
    max_len: Maximum sequence length
    full_embeddings: Whether to return full residue-wise (True) or pooled (False) embeddings
    embed_dtype: Dtype to cast embeddings to
    pooling_types: Pooling strategies ('mean', 'cls'); multiple types are concatenated
    num_workers: Number of workers for data loading, 0 for the main process
    sql: Whether to store embeddings in a SQLite database - stored in float32
    sql_db_path: Path to the SQLite database
    save: Whether to save embeddings as a .pth file
    save_path: Path to the .pth file

Returns:
    Dictionary mapping sequences to embeddings, or None if sql=True

Note:
    - If sql=True, embeddings can only be stored in float32
    - sql is ideal if you need to stream a very large dataset for training in real time
    - save=True is ideal if you can store the entire embedding dictionary in RAM
    - sql takes precedence: if sql=True, the database is used regardless of save
    - If your sql database or .pth file already exists, it is scanned first for already-embedded sequences
    - Sequences are truncated to max_len and sorted by length in descending order for faster processing
```
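With `save=True`, the returned dictionary is also written to `save_path` and can be reloaded later with `torch.load`, avoiding re-embedding. The sketch below builds a small stand-in dictionary so it is self-contained; the dict-of-tensors format follows the description above (embedding sizes are placeholders):

```python
import torch

# stand-in for an embed_dataset result saved with save=True:
# a dict mapping each sequence string to its pooled embedding tensor
embedding_dict = {
    'MPRTEIN': torch.randn(320),
    'MSEQWENCE': torch.randn(320),
}
torch.save(embedding_dict, 'embeddings.pth')

# reload later without re-embedding anything
loaded = torch.load('embeddings.pth')
print(sorted(loaded.keys()))  # ['MPRTEIN', 'MSEQWENCE']
```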


### Citation
If you use any of this implementation or work, please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
```
@misc {FastPLMs,
  author = { Hallee, Logan and Bichara, David and Gleghorn, Jason P. },
  title = { FastPLMs: Fast, efficient, protein language model inference from Huggingface AutoModel. },
  year = {2024},
  url = { https://huggingface.co/Synthyra/ESMplusplus_small },
  DOI = { 10.57967/hf/3726 },
  publisher = { Hugging Face }
}
```