Upload README.md with huggingface_hub

README.md
Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.
## Attention backends
`sdpa` (PyTorch Scaled Dot Product Attention) is the default. It is fast, memory-efficient, and numerically equivalent to naive attention. The backend is set via `config.attn_backend` before loading.
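The claimed equivalence to naive attention is easy to check directly. A minimal sketch, assuming PyTorch ≥ 2.0 (where `torch.nn.functional.scaled_dot_product_attention` is available); the tensor shapes are arbitrary:

```python
import torch
import torch.nn.functional as F

# Random query/key/value with shape (batch, heads, seq, head_dim).
q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))

# Fused SDPA path.
out_sdpa = F.scaled_dot_product_attention(q, k, v)

# Naive reference: softmax(QK^T / sqrt(d)) V with d = head_dim = 16.
scores = (q @ k.transpose(-2, -1)) / (16 ** 0.5)
out_naive = torch.softmax(scores, dim=-1) @ v

# The two paths agree to floating-point precision.
assert torch.allclose(out_sdpa, out_naive, atol=1e-5)
```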
| Backend | Key | Notes |
| :--- | :--- | :--- |
| PyTorch SDPA | `"sdpa"` | Default. Exact numerics, stable on all hardware. |
| Flash Attention | `"kernels_flash"` | Fastest. Requires `pip install kernels` (pre-built, no hours-long compilation). Outputs differ slightly from SDPA (online softmax reordering), but the differences are numerically harmless. |
| Flex Attention | `"flex"` | Skips padding tokens via a block mask, so it is faster on variable-length batches. Near-exact numerics. First use compiles a Triton kernel (30–120 s). |
| Auto | `"auto"` | Picks the best available: `kernels_flash` → `flex` → `sdpa`. |
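The `"auto"` fallback order amounts to a simple availability probe. A hypothetical sketch of that logic (the helper name and checks are illustrative, not part of FastESM's actual API):

```python
import importlib.util

def pick_attn_backend() -> str:
    """Illustrative sketch of the "auto" fallback: kernels_flash -> flex -> sdpa."""
    # Flash via the `kernels` package, if it is installed.
    if importlib.util.find_spec("kernels") is not None:
        return "kernels_flash"
    try:
        # Flex Attention ships with recent PyTorch releases.
        from torch.nn.attention.flex_attention import flex_attention  # noqa: F401
        return "flex"
    except ImportError:
        # SDPA is always available as the stable fallback.
        return "sdpa"
```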

```python
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("Synthyra/ESM2-150M", trust_remote_code=True)
config.attn_backend = "flex"  # or "kernels_flash", "sdpa", "auto"
model = AutoModel.from_pretrained("Synthyra/ESM2-150M", config=config, trust_remote_code=True)
```
`torch.compile(model)` is strongly recommended for sustained throughput, especially with Flex Attention.
Attention maps (`output_attentions=True`) are supported with all backends. For SDPA, Flash, and Flex, the attention weights are computed via a separate naive pass, so there is no memory benefit to enabling it during normal inference.
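That naive pass materializes the full `(seq, seq)` probability matrix per head, which is exactly the quadratic memory that fused kernels avoid. A toy illustration (shapes are arbitrary):

```python
import torch

# Naive attention weights: the full (seq x seq) matrix is materialized,
# so requesting attention maps forfeits the memory savings of fused kernels.
q = torch.randn(1, 2, 6, 8)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, 6, 8)
weights = torch.softmax(q @ k.transpose(-2, -1) / 8 ** 0.5, dim=-1)
# weights has shape (batch, heads, seq, seq): quadratic in sequence length.
```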
Various other optimizations also make the base implementation slightly different from the one in transformers.
## Use with 🤗 transformers