davda54 commited on
Commit
64f341a
·
verified ·
1 Parent(s): c96534d

fix flash attention

Browse files
Files changed (1) hide show
  1. modeling_gptbert.py +8 -6
modeling_gptbert.py CHANGED
@@ -270,6 +270,7 @@ def flash_attention_forward(
270
  rotary_emb: UnpaddedRotaryEmbedding,
271
  cu_seqlens: torch.Tensor,
272
  max_seqlen: int,
 
273
  local_attention: Tuple[int, int],
274
  dropout_p: float,
275
  deterministic: bool,
@@ -289,9 +290,9 @@ def flash_attention_forward(
289
  qkv,
290
  cu_seqlens=cu_seqlens,
291
  max_seqlen=max_seqlen,
292
- # dropout_p=dropout_p,
293
- # deterministic=deterministic,
294
- # window_size=local_attention,
295
  causal=False
296
  )
297
  attn = attn.to(orig_dtype) # type: ignore
@@ -300,9 +301,9 @@ def flash_attention_forward(
300
  qkv,
301
  cu_seqlens=cu_seqlens,
302
  max_seqlen=max_seqlen,
303
- # dropout_p=dropout_p,
304
- # deterministic=deterministic,
305
- # window_size=local_attention,
306
  causal=False
307
  )
308
  return attn
@@ -460,6 +461,7 @@ class SelfAttention(nn.Module):
460
  self.rope_embedding,
461
  cu_seqlens,
462
  max_seqlen,
 
463
  local_attention,
464
  self.attention_dropout if self.training else 0.0,
465
  self.deterministic_flash_attn
 
270
  rotary_emb: UnpaddedRotaryEmbedding,
271
  cu_seqlens: torch.Tensor,
272
  max_seqlen: int,
273
+ causal: bool,
274
  local_attention: Tuple[int, int],
275
  dropout_p: float,
276
  deterministic: bool,
 
290
  qkv,
291
  cu_seqlens=cu_seqlens,
292
  max_seqlen=max_seqlen,
293
+ dropout_p=dropout_p,
294
+ deterministic=deterministic,
295
+ window_size=local_attention,
296
  causal=False
297
  )
298
  attn = attn.to(orig_dtype) # type: ignore
 
301
  qkv,
302
  cu_seqlens=cu_seqlens,
303
  max_seqlen=max_seqlen,
304
+ dropout_p=dropout_p,
305
+ deterministic=deterministic,
306
+ window_size=local_attention,
307
  causal=False
308
  )
309
  return attn
 
461
  self.rope_embedding,
462
  cu_seqlens,
463
  max_seqlen,
464
+ self.is_causal,
465
  local_attention,
466
  self.attention_dropout if self.training else 0.0,
467
  self.deterministic_flash_attn