Sentence Similarity
Transformers
Safetensors
English
mistral
feature-extraction
text-embedding
embeddings
information-retrieval
beir
text-classification
language-model
text-clustering
text-semantic-similarity
text-evaluation
text-reranking
natural_questions
ms_marco
fever
hotpot_qa
mteb
custom_code
text-generation-inference
text-embeddings-inference
Update attn_mask_utils.py
attn_mask_utils.py  CHANGED  (+29 -7)
@@ -1,7 +1,19 @@
 from typing import List, Optional, Tuple, Union
 import torch
+from packaging import version
+import importlib.metadata
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
+from transformers.utils.import_utils import _is_package_available
+
+def is_transformers_attn_greater_or_equal_4_39():
+    if not _is_package_available("transformers"):
+        return False
+
+    return version.parse(importlib.metadata.version("transformers")) >= version.parse(
+        "4.39.0"
+    )
+
 def _prepare_4d_attention_mask_for_sdpa(
     attention_mask: Optional[torch.Tensor],
     input_shape: Union[torch.Size, Tuple, List],
@@ -59,9 +71,14 @@ def _prepare_4d_attention_mask_for_sdpa(
     # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
     # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
     if query_length > 1:
-        expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
-            expanded_4d_mask, attention_mask, unmasked_value=0.0
-        )
+        if is_transformers_attn_greater_or_equal_4_39():
+            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
+            )
+        else:
+            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                expanded_4d_mask, attention_mask, unmasked_value=0.0
+            )
 
     return expanded_4d_mask
 
@@ -195,8 +212,13 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     # controlflow that can not be captured properly.
     # TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
     if query_length > 1 and not is_tracing:
-        expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
-            expanded_4d_mask, attention_mask, unmasked_value=0.0
-        )
+        if is_transformers_attn_greater_or_equal_4_39():
+            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
+            )
+        else:
+            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                expanded_4d_mask, attention_mask, unmasked_value=0.0
+            )
 
-    return expanded_4d_mask
+    return expanded_4d_mask
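
Background note (not part of the commit): transformers 4.39 changed how AttentionMaskConverter._unmask_unattended is called, from the older (expanded_mask, attention_mask, unmasked_value=...) form to (expanded_mask, min_dtype=...), which is what the version gate above accommodates so the same custom code loads against both older and newer transformers releases. The sketch below is a hypothetical sanity check, assuming packaging and transformers are installed in the environment; it only mirrors the version comparison added in attn_mask_utils.py and reports which branch of the gate would run, with illustrative variable names that are not from the repo.

# Hypothetical sanity check (not part of the commit): report which
# _unmask_unattended call style the installed transformers version selects,
# mirroring the gate added in attn_mask_utils.py.
import importlib.metadata
from packaging import version

installed = version.parse(importlib.metadata.version("transformers"))
if installed >= version.parse("4.39.0"):
    # >= 4.39: _unmask_unattended(expanded_mask, min_dtype=...)
    print(f"transformers {installed}: min_dtype signature expected")
else:
    # < 4.39: _unmask_unattended(expanded_mask, attention_mask, unmasked_value=...)
    print(f"transformers {installed}: attention_mask/unmasked_value signature expected")

Gating on the installed version keeps a single copy of the custom modeling code working across transformers releases, at the cost of depending on a private helper whose signature may change again.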
|