lhallee commited on
Commit
309b571
·
verified ·
1 Parent(s): 0004bb7

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +25 -2
README.md CHANGED
@@ -36,8 +36,31 @@ with torch.no_grad():
36
  logits = mlm(**batch).logits
37
  ```
38
 
39
- ## Attention backend
40
- `sdpa` is the default backend. Flex Attention is available by setting `config.attn_backend = "flex"` before loading.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  ## Embed datasets
43
  All DPLM models inherit `EmbeddingMixin`, so you can call `model.embed_dataset(...)` directly.
 
36
  logits = mlm(**batch).logits
37
  ```
38
 
39
+ ## Attention backends
40
+
41
+ `sdpa` (PyTorch Scaled Dot Product Attention) is the default.
42
+
43
+ | Backend | Key | Notes |
44
+ | :--- | :--- | :--- |
45
+ | PyTorch SDPA | `"sdpa"` | Default. Exact numerics, stable on all hardware. |
46
+ | Flash Attention | `"kernels_flash"` | Fastest on Ampere/Hopper GPUs. Requires `pip install kernels` (pre-built — no hours-long compilation). Outputs differ slightly from SDPA due to online softmax reordering, but differences are numerically harmless. |
47
+ | Flex Attention | `"flex"` | Skips padding tokens via block mask — faster on variable-length batches. Near-exact numerics. First use compiles a Triton kernel (30–120 s). Best combined with `torch.compile`. |
48
+ | Auto | `"auto"` | Picks the best available: `kernels_flash` → `flex` → `sdpa`. |
49
+
50
+ Set via config before loading, or change on the model after loading (DPLM propagates the change to all attention layers immediately):
51
+
52
+ ```python
53
+ from transformers import AutoConfig, AutoModel
54
+
55
+ # Option 1: set before loading
56
+ config = AutoConfig.from_pretrained("Synthyra/DPLM-150M", trust_remote_code=True)
57
+ config.attn_backend = "flex"
58
+ model = AutoModel.from_pretrained("Synthyra/DPLM-150M", config=config, trust_remote_code=True)
59
+
60
+ # Option 2: set after loading
61
+ model = AutoModel.from_pretrained("Synthyra/DPLM-150M", trust_remote_code=True)
62
+ model.attn_backend = "flex" # propagates to all attention layers in-place
63
+ ```
64
 
65
  ## Embed datasets
66
  All DPLM models inherit `EmbeddingMixin`, so you can call `model.embed_dataset(...)` directly.