lhallee commited on
Commit
e672711
·
verified ·
1 Parent(s): b0584b5

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +19 -5
README.md CHANGED
@@ -11,14 +11,28 @@ FastESM is a Huggingface compatible plug in version of ESM2 rewritten with a new
11
 
12
  Load any ESM2 models into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.
13
 
14
- ## Attention backend defaults
15
- `sdpa` is the default attention backend for FastESM.
16
 
17
- To enable Flex Attention, set `attn_backend="flex"` on the config before model initialization/loading.
18
 
19
- For throughput and memory efficiency, `torch.compile(...)` is heavily recommended, especially when using Flex Attention.
 
 
 
 
 
20
 
21
- Outputting attention maps (or the contact prediction head) is not natively possible with the optimized attention backends (including Flex Attention). You can still pass ```output_attentions``` to have attention calculated manually and returned.
 
 
 
 
 
 
 
 
 
 
22
  Various other optimizations also make the base implementation slightly different than the one in transformers.
23
 
24
  ## Use with 🤗 transformers
 
11
 
12
  Load any ESM2 models into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.
13
 
14
+ ## Attention backends
 
15
 
16
+ `sdpa` (PyTorch Scaled Dot Product Attention) is the default. It is fast, memory-efficient, and numerically equivalent to naive attention. The backend is set via `config.attn_backend` before loading.
17
 
18
+ | Backend | Key | Notes |
19
+ | :--- | :--- | :--- |
20
+ | PyTorch SDPA | `"sdpa"` | Default. Exact numerics, stable on all hardware. |
21
+ | Flash Attention | `"kernels_flash"` | Fastest. Requires `pip install kernels` (pre-built — no hours-long compilation). Outputs differ slightly from SDPA (online softmax reordering) but are numerically harmless. |
22
+ | Flex Attention | `"flex"` | Skips padding tokens via block mask — faster on variable-length batches. Near-exact numerics. First use compiles a Triton kernel (30–120 s). |
23
+ | Auto | `"auto"` | Picks the best available: `kernels_flash` → `flex` → `sdpa`. |
24
 
25
+ ```python
26
+ from transformers import AutoConfig, AutoModel
27
+
28
+ config = AutoConfig.from_pretrained("Synthyra/ESM2-150M", trust_remote_code=True)
29
+ config.attn_backend = "flex" # or "kernels_flash", "sdpa", "auto"
30
+ model = AutoModel.from_pretrained("Synthyra/ESM2-150M", config=config, trust_remote_code=True)
31
+ ```
32
+
33
+ `torch.compile(model)` is heavily recommended for sustained throughput, especially with Flex Attention.
34
+
35
+ Attention maps (`output_attentions=True`) are supported with all backends. For SDPA, Flash, and Flex, the attention weights are computed via a separate naive pass, so there is no memory benefit to enabling it during normal inference.
36
  Various other optimizations also make the base implementation slightly different than the one in transformers.
37
 
38
  ## Use with 🤗 transformers