hanjian.thu123 commited on
Commit
b54fafe
·
1 Parent(s): 5e08e4d

[update] slow attn

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. grn/models/basic.py +1 -18
  3. requirements.txt +0 -1
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pyc
grn/models/basic.py CHANGED
@@ -201,24 +201,7 @@ class SelfAttention(nn.Module):
201
  if self.use_flex_attn and attn_fn is not None:
202
  attn_output = attn_fn(query_states.to(value_states.dtype), key_states.to(value_states.dtype), value_states, scale=scale).transpose(1, 2).reshape(B, L, C)
203
  else:
204
- if attn_bias_or_two_vector is None:
205
- # fa2, flash_attn_func input/output should be (batch_size, seqlen, nheads, headdim)
206
- from flash_attn import flash_attn_func, flash_attn_varlen_func
207
- attn_output = flash_attn_varlen_func(
208
- q = query_states.permute([0,2,1,3]).to(torch.bfloat16).squeeze(0),
209
- k = key_states.permute([0,2,1,3]).to(torch.bfloat16).squeeze(0),
210
- v = value_states.permute([0,2,1,3]).to(torch.bfloat16).squeeze(0),
211
- cu_seqlens_q = torch.tensor([0] + split_cond_uncond, device=query_states.device).cumsum(-1).to(torch.int32),
212
- cu_seqlens_k = torch.tensor([0] + cu_seqlens_k, device=query_states.device).cumsum(-1).to(torch.int32),
213
- max_seqlen_q = max(split_cond_uncond),
214
- max_seqlen_k = max(cu_seqlens_k),
215
- softmax_scale=scale,
216
- )
217
- attn_output = attn_output.reshape(B, L, C)
218
- # attn_output = flash_attn_func(query_states.permute([0,2,1,3]).to(torch.bfloat16), key_states.permute([0,2,1,3]).to(torch.bfloat16), value_states.permute([0,2,1,3]).to(torch.bfloat16), softmax_scale=scale)
219
- else:
220
- # slow attn
221
- attn_output = slow_attn(query=query_states, key=key_states, value=value_states, scale=scale, attn_mask=attn_bias_or_two_vector, dropout_p=0).transpose(1, 2).reshape(B, L, C)
222
 
223
  # fa3, flash_attn_func input/output should be (batch_size, seqlen, nheads, headdim)
224
  # from flash_attn_interface import flash_attn_qkvpacked_func, flash_attn_func
 
201
  if self.use_flex_attn and attn_fn is not None:
202
  attn_output = attn_fn(query_states.to(value_states.dtype), key_states.to(value_states.dtype), value_states, scale=scale).transpose(1, 2).reshape(B, L, C)
203
  else:
204
+ attn_output = slow_attn(query=query_states, key=key_states, value=value_states, scale=scale, attn_mask=attn_bias_or_two_vector, dropout_p=0).transpose(1, 2).reshape(B, L, C)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  # fa3, flash_attn_func input/output should be (batch_size, seqlen, nheads, headdim)
207
  # from flash_attn_interface import flash_attn_qkvpacked_func, flash_attn_func
requirements.txt CHANGED
@@ -15,4 +15,3 @@ ftfy>=6.1.1
15
  transformers>=4.35.0
16
  regex>=2023.10.3
17
  pyyaml>=6.0
18
- flash-attn
 
15
  transformers>=4.35.0
16
  regex>=2023.10.3
17
  pyyaml>=6.0