Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +32 -31
modeling_esm_plusplus.py
CHANGED
|
@@ -36,8 +36,13 @@ try:
|
|
| 36 |
# when used from AutoModel, these are in the same directory
|
| 37 |
from .embedding_mixin import EmbeddingMixin, Pooler
|
| 38 |
except:
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
|
|
@@ -384,35 +389,27 @@ class MultiHeadAttention(nn.Module):
|
|
| 384 |
attn_weights = F.softmax(attn_weights, dim=-1)
|
| 385 |
context_BHLD = torch.matmul(attn_weights, value_BHLD)
|
| 386 |
else:
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
and flex_dtype_supported
|
| 396 |
-
and (attention_mask is None or flex_block_mask is not None)
|
| 397 |
-
)
|
| 398 |
-
if use_flex:
|
| 399 |
-
try:
|
| 400 |
-
context_BHLD = flex_attention(
|
| 401 |
-
query_BHLD,
|
| 402 |
-
key_BHLD,
|
| 403 |
-
value_BHLD,
|
| 404 |
-
block_mask=flex_block_mask,
|
| 405 |
-
scale=scale,
|
| 406 |
-
)
|
| 407 |
-
except Exception:
|
| 408 |
-
context_BHLD = F.scaled_dot_product_attention(
|
| 409 |
-
query_BHLD,
|
| 410 |
-
key_BHLD,
|
| 411 |
-
value_BHLD,
|
| 412 |
-
attn_mask=sdpa_mask,
|
| 413 |
-
scale=scale,
|
| 414 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
context_BHLD = F.scaled_dot_product_attention(
|
| 417 |
query_BHLD,
|
| 418 |
key_BHLD,
|
|
@@ -577,11 +574,15 @@ class TransformerStack(nn.Module):
|
|
| 577 |
if attention_mask is not None:
|
| 578 |
assert attention_mask.ndim == 2, f"Expected 2D token attention mask, got shape {attention_mask.shape}."
|
| 579 |
token_attention_mask = attention_mask.bool()
|
| 580 |
-
pairwise_attention_mask = token_attention_mask.unsqueeze(-1) & token_attention_mask.unsqueeze(-2)
|
| 581 |
-
attention_mask = pairwise_attention_mask.unsqueeze(1)
|
| 582 |
if self.attn_backend == "flex" and not output_attentions:
|
|
|
|
|
|
|
|
|
|
| 583 |
flex_block_mask = _create_pad_block_mask(token_attention_mask)
|
|
|
|
| 584 |
else:
|
|
|
|
|
|
|
| 585 |
flex_block_mask = None
|
| 586 |
else:
|
| 587 |
flex_block_mask = None
|
|
|
|
| 36 |
# when used from AutoModel, these are in the same directory
|
| 37 |
from .embedding_mixin import EmbeddingMixin, Pooler
|
| 38 |
except:
|
| 39 |
+
try:
|
| 40 |
+
# when importing as a submodule, embedding mixin is in the FastPLMs directory
|
| 41 |
+
from ..embedding_mixin import EmbeddingMixin, Pooler
|
| 42 |
+
except:
|
| 43 |
+
# when running from our repo, these are in the base directory
|
| 44 |
+
from embedding_mixin import EmbeddingMixin, Pooler
|
| 45 |
+
|
| 46 |
|
| 47 |
|
| 48 |
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
|
|
|
|
| 389 |
attn_weights = F.softmax(attn_weights, dim=-1)
|
| 390 |
context_BHLD = torch.matmul(attn_weights, value_BHLD)
|
| 391 |
else:
|
| 392 |
+
if self.attn_backend == "flex":
|
| 393 |
+
assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
|
| 394 |
+
assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
|
| 395 |
+
f"Flex attention backend requires float16 or bfloat16, got {query_BHLD.dtype}."
|
| 396 |
+
)
|
| 397 |
+
if attention_mask is not None:
|
| 398 |
+
assert flex_block_mask is not None, (
|
| 399 |
+
"Flex attention backend requires a block mask when attention_mask is provided."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
)
|
| 401 |
+
context_BHLD = flex_attention(
|
| 402 |
+
query_BHLD,
|
| 403 |
+
key_BHLD,
|
| 404 |
+
value_BHLD,
|
| 405 |
+
block_mask=flex_block_mask,
|
| 406 |
+
scale=scale,
|
| 407 |
+
)
|
| 408 |
else:
|
| 409 |
+
sdpa_mask = None
|
| 410 |
+
if attention_mask is not None:
|
| 411 |
+
sdpa_mask = torch.zeros_like(attention_mask, dtype=query_BHLD.dtype)
|
| 412 |
+
sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
| 413 |
context_BHLD = F.scaled_dot_product_attention(
|
| 414 |
query_BHLD,
|
| 415 |
key_BHLD,
|
|
|
|
| 574 |
if attention_mask is not None:
|
| 575 |
assert attention_mask.ndim == 2, f"Expected 2D token attention mask, got shape {attention_mask.shape}."
|
| 576 |
token_attention_mask = attention_mask.bool()
|
|
|
|
|
|
|
| 577 |
if self.attn_backend == "flex" and not output_attentions:
|
| 578 |
+
assert create_block_mask is not None, (
|
| 579 |
+
"Flex attention backend requested but torch.create_block_mask is unavailable."
|
| 580 |
+
)
|
| 581 |
flex_block_mask = _create_pad_block_mask(token_attention_mask)
|
| 582 |
+
attention_mask = None
|
| 583 |
else:
|
| 584 |
+
pairwise_attention_mask = token_attention_mask.unsqueeze(-1) & token_attention_mask.unsqueeze(-2)
|
| 585 |
+
attention_mask = pairwise_attention_mask.unsqueeze(1)
|
| 586 |
flex_block_mask = None
|
| 587 |
else:
|
| 588 |
flex_block_mask = None
|