Spaces:
Running on Zero
Running on Zero
hanjian.thu123 committed on
Commit ·
b54fafe
1
Parent(s): 5e08e4d
[update] slow attn
Browse files
- .gitignore +1 -0
- grn/models/basic.py +1 -18
- requirements.txt +0 -1
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pyc
|
grn/models/basic.py
CHANGED
|
@@ -201,24 +201,7 @@ class SelfAttention(nn.Module):
|
|
| 201 |
if self.use_flex_attn and attn_fn is not None:
|
| 202 |
attn_output = attn_fn(query_states.to(value_states.dtype), key_states.to(value_states.dtype), value_states, scale=scale).transpose(1, 2).reshape(B, L, C)
|
| 203 |
else:
|
| 204 |
-
|
| 205 |
-
# fa2, flash_attn_func input/output should be (batch_size, seqlen, nheads, headdim)
|
| 206 |
-
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 207 |
-
attn_output = flash_attn_varlen_func(
|
| 208 |
-
q = query_states.permute([0,2,1,3]).to(torch.bfloat16).squeeze(0),
|
| 209 |
-
k = key_states.permute([0,2,1,3]).to(torch.bfloat16).squeeze(0),
|
| 210 |
-
v = value_states.permute([0,2,1,3]).to(torch.bfloat16).squeeze(0),
|
| 211 |
-
cu_seqlens_q = torch.tensor([0] + split_cond_uncond, device=query_states.device).cumsum(-1).to(torch.int32),
|
| 212 |
-
cu_seqlens_k = torch.tensor([0] + cu_seqlens_k, device=query_states.device).cumsum(-1).to(torch.int32),
|
| 213 |
-
max_seqlen_q = max(split_cond_uncond),
|
| 214 |
-
max_seqlen_k = max(cu_seqlens_k),
|
| 215 |
-
softmax_scale=scale,
|
| 216 |
-
)
|
| 217 |
-
attn_output = attn_output.reshape(B, L, C)
|
| 218 |
-
# attn_output = flash_attn_func(query_states.permute([0,2,1,3]).to(torch.bfloat16), key_states.permute([0,2,1,3]).to(torch.bfloat16), value_states.permute([0,2,1,3]).to(torch.bfloat16), softmax_scale=scale)
|
| 219 |
-
else:
|
| 220 |
-
# slow attn
|
| 221 |
-
attn_output = slow_attn(query=query_states, key=key_states, value=value_states, scale=scale, attn_mask=attn_bias_or_two_vector, dropout_p=0).transpose(1, 2).reshape(B, L, C)
|
| 222 |
|
| 223 |
# fa3, flash_attn_func input/output should be (batch_size, seqlen, nheads, headdim)
|
| 224 |
# from flash_attn_interface import flash_attn_qkvpacked_func, flash_attn_func
|
|
|
|
| 201 |
if self.use_flex_attn and attn_fn is not None:
|
| 202 |
attn_output = attn_fn(query_states.to(value_states.dtype), key_states.to(value_states.dtype), value_states, scale=scale).transpose(1, 2).reshape(B, L, C)
|
| 203 |
else:
|
| 204 |
+
attn_output = slow_attn(query=query_states, key=key_states, value=value_states, scale=scale, attn_mask=attn_bias_or_two_vector, dropout_p=0).transpose(1, 2).reshape(B, L, C)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
# fa3, flash_attn_func input/output should be (batch_size, seqlen, nheads, headdim)
|
| 207 |
# from flash_attn_interface import flash_attn_qkvpacked_func, flash_attn_func
|
requirements.txt
CHANGED
|
@@ -15,4 +15,3 @@ ftfy>=6.1.1
|
|
| 15 |
transformers>=4.35.0
|
| 16 |
regex>=2023.10.3
|
| 17 |
pyyaml>=6.0
|
| 18 |
-
flash-attn
|
|
|
|
| 15 |
transformers>=4.35.0
|
| 16 |
regex>=2023.10.3
|
| 17 |
pyyaml>=6.0
|
|
|