Upload modeling_fastesm.py with huggingface_hub
Browse files
modeling_fastesm.py (+11 additions, -2 deletions)
modeling_fastesm.py
CHANGED
|
@@ -508,7 +508,7 @@ class EsmSelfAttention(nn.Module):
|
|
| 508 |
self.scale = self.attention_head_size**-0.5
|
| 509 |
|
| 510 |
self.dropout_prob = config.attention_probs_dropout_prob
|
| 511 |
-
self.
|
| 512 |
self.position_embedding_type = position_embedding_type or getattr(
|
| 513 |
config, "position_embedding_type", "absolute"
|
| 514 |
)
|
|
@@ -555,7 +555,7 @@ class EsmSelfAttention(nn.Module):
|
|
| 555 |
context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
|
| 556 |
return context_layer, attention_probs
|
| 557 |
else:
|
| 558 |
-
if self.attn_backend == "flex":
|
| 559 |
assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
|
| 560 |
assert query_layer.dtype in (torch.float16, torch.bfloat16), f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
|
| 561 |
assert flex_block_mask is not None, "Flex attention backend requires a block mask"
|
|
@@ -771,6 +771,15 @@ class FastEsmPreTrainedModel(PreTrainedModel):
|
|
| 771 |
# See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
|
| 772 |
return None
|
| 773 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 774 |
|
| 775 |
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
| 776 |
def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
|
|
|
|
| 508 |
self.scale = self.attention_head_size**-0.5
|
| 509 |
|
| 510 |
self.dropout_prob = config.attention_probs_dropout_prob
|
| 511 |
+
self.config = config
|
| 512 |
self.position_embedding_type = position_embedding_type or getattr(
|
| 513 |
config, "position_embedding_type", "absolute"
|
| 514 |
)
|
|
|
|
| 555 |
context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
|
| 556 |
return context_layer, attention_probs
|
| 557 |
else:
|
| 558 |
+
if self.config.attn_backend == "flex":
|
| 559 |
assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
|
| 560 |
assert query_layer.dtype in (torch.float16, torch.bfloat16), f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
|
| 561 |
assert flex_block_mask is not None, "Flex attention backend requires a block mask"
|
|
|
|
| 771 |
# See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
|
| 772 |
return None
|
| 773 |
|
| 774 |
+
@property
|
| 775 |
+
def attn_backend(self) -> str:
|
| 776 |
+
return self.config.attn_backend
|
| 777 |
+
|
| 778 |
+
@attn_backend.setter
|
| 779 |
+
def attn_backend(self, backend: str) -> None:
|
| 780 |
+
assert backend in ("sdpa", "flex"), f"Unsupported attn_backend: {backend}"
|
| 781 |
+
self.config.attn_backend = backend
|
| 782 |
+
|
| 783 |
|
| 784 |
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
| 785 |
def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
|