Upload modeling_esm_plusplus.py with huggingface_hub
Browse files — modeling_esm_plusplus.py (+11 −3)
modeling_esm_plusplus.py
CHANGED
|
@@ -29,7 +29,12 @@ from torch.nn.attention.flex_attention import flex_attention
|
|
| 29 |
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
|
| 30 |
from transformers.modeling_outputs import ModelOutput
|
| 31 |
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
|
|
@@ -350,10 +355,10 @@ class MultiHeadAttention(nn.Module):
|
|
| 350 |
)
|
| 351 |
query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
|
| 352 |
query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
|
|
|
|
| 353 |
|
| 354 |
if output_attentions: # Manual attention computation
|
| 355 |
-
b, h, l,
|
| 356 |
-
scale = 1 / math.sqrt(d)
|
| 357 |
attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
|
| 358 |
if attention_mask is not None:
|
| 359 |
attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
|
|
@@ -377,6 +382,7 @@ class MultiHeadAttention(nn.Module):
|
|
| 377 |
key_BHLD,
|
| 378 |
value_BHLD,
|
| 379 |
block_mask=flex_block_mask,
|
|
|
|
| 380 |
)
|
| 381 |
except Exception as exc:
|
| 382 |
if not self._warned_flex_fallback:
|
|
@@ -390,6 +396,7 @@ class MultiHeadAttention(nn.Module):
|
|
| 390 |
key_BHLD,
|
| 391 |
value_BHLD,
|
| 392 |
attn_mask=sdpa_mask,
|
|
|
|
| 393 |
)
|
| 394 |
else:
|
| 395 |
context_BHLD = F.scaled_dot_product_attention(
|
|
@@ -397,6 +404,7 @@ class MultiHeadAttention(nn.Module):
|
|
| 397 |
key_BHLD,
|
| 398 |
value_BHLD,
|
| 399 |
attn_mask=sdpa_mask,
|
|
|
|
| 400 |
)
|
| 401 |
|
| 402 |
context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
|
|
|
|
| 29 |
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
|
| 30 |
from transformers.modeling_outputs import ModelOutput
|
| 31 |
|
| 32 |
+
try:
|
| 33 |
+
# when used from AutoModel, these are in the same directory
|
| 34 |
+
from .embedding_mixin import EmbeddingMixin, Pooler
|
| 35 |
+
except:
|
| 36 |
+
# when running from our repo, these are in the base directory
|
| 37 |
+
from embedding_mixin import EmbeddingMixin, Pooler
|
| 38 |
|
| 39 |
|
| 40 |
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
|
|
|
|
| 355 |
)
|
| 356 |
query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
|
| 357 |
query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
|
| 358 |
+
scale = 1 / math.sqrt(self.d_head)
|
| 359 |
|
| 360 |
if output_attentions: # Manual attention computation
|
| 361 |
+
b, h, l, _ = query_BHLD.shape
|
|
|
|
| 362 |
attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
|
| 363 |
if attention_mask is not None:
|
| 364 |
attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
|
|
|
|
| 382 |
key_BHLD,
|
| 383 |
value_BHLD,
|
| 384 |
block_mask=flex_block_mask,
|
| 385 |
+
scale=scale,
|
| 386 |
)
|
| 387 |
except Exception as exc:
|
| 388 |
if not self._warned_flex_fallback:
|
|
|
|
| 396 |
key_BHLD,
|
| 397 |
value_BHLD,
|
| 398 |
attn_mask=sdpa_mask,
|
| 399 |
+
scale=scale,
|
| 400 |
)
|
| 401 |
else:
|
| 402 |
context_BHLD = F.scaled_dot_product_attention(
|
|
|
|
| 404 |
key_BHLD,
|
| 405 |
value_BHLD,
|
| 406 |
attn_mask=sdpa_mask,
|
| 407 |
+
scale=scale,
|
| 408 |
)
|
| 409 |
|
| 410 |
context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
|