fix flash attention
Browse files- modeling_gptbert.py +8 -6
modeling_gptbert.py
CHANGED
|
@@ -270,6 +270,7 @@ def flash_attention_forward(
|
|
| 270 |
rotary_emb: UnpaddedRotaryEmbedding,
|
| 271 |
cu_seqlens: torch.Tensor,
|
| 272 |
max_seqlen: int,
|
|
|
|
| 273 |
local_attention: Tuple[int, int],
|
| 274 |
dropout_p: float,
|
| 275 |
deterministic: bool,
|
|
@@ -289,9 +290,9 @@ def flash_attention_forward(
|
|
| 289 |
qkv,
|
| 290 |
cu_seqlens=cu_seqlens,
|
| 291 |
max_seqlen=max_seqlen,
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
causal=False
|
| 296 |
)
|
| 297 |
attn = attn.to(orig_dtype) # type: ignore
|
|
@@ -300,9 +301,9 @@ def flash_attention_forward(
|
|
| 300 |
qkv,
|
| 301 |
cu_seqlens=cu_seqlens,
|
| 302 |
max_seqlen=max_seqlen,
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
causal=False
|
| 307 |
)
|
| 308 |
return attn
|
|
@@ -460,6 +461,7 @@ class SelfAttention(nn.Module):
|
|
| 460 |
self.rope_embedding,
|
| 461 |
cu_seqlens,
|
| 462 |
max_seqlen,
|
|
|
|
| 463 |
local_attention,
|
| 464 |
self.attention_dropout if self.training else 0.0,
|
| 465 |
self.deterministic_flash_attn
|
|
|
|
| 270 |
rotary_emb: UnpaddedRotaryEmbedding,
|
| 271 |
cu_seqlens: torch.Tensor,
|
| 272 |
max_seqlen: int,
|
| 273 |
+
causal: bool,
|
| 274 |
local_attention: Tuple[int, int],
|
| 275 |
dropout_p: float,
|
| 276 |
deterministic: bool,
|
|
|
|
| 290 |
qkv,
|
| 291 |
cu_seqlens=cu_seqlens,
|
| 292 |
max_seqlen=max_seqlen,
|
| 293 |
+
dropout_p=dropout_p,
|
| 294 |
+
deterministic=deterministic,
|
| 295 |
+
window_size=local_attention,
|
| 296 |
causal=False
|
| 297 |
)
|
| 298 |
attn = attn.to(orig_dtype) # type: ignore
|
|
|
|
| 301 |
qkv,
|
| 302 |
cu_seqlens=cu_seqlens,
|
| 303 |
max_seqlen=max_seqlen,
|
| 304 |
+
dropout_p=dropout_p,
|
| 305 |
+
deterministic=deterministic,
|
| 306 |
+
window_size=local_attention,
|
| 307 |
causal=False
|
| 308 |
)
|
| 309 |
return attn
|
|
|
|
| 461 |
self.rope_embedding,
|
| 462 |
cu_seqlens,
|
| 463 |
max_seqlen,
|
| 464 |
+
self.is_causal,
|
| 465 |
local_attention,
|
| 466 |
self.attention_dropout if self.training else 0.0,
|
| 467 |
self.deterministic_flash_attn
|