Commit: Upload modeling_esm_plusplus.py with huggingface_hub
File changed: modeling_esm_plusplus.py (+11 insertions, −9 deletions)
modeling_esm_plusplus.py
CHANGED
|
@@ -399,9 +399,9 @@ def get_attention_mask(
|
|
| 399 |
attention_mask: Optional[torch.Tensor] = None
|
| 400 |
) -> torch.Tensor:
|
| 401 |
if attention_mask is None:
|
| 402 |
-
|
| 403 |
else:
|
| 404 |
-
|
| 405 |
|
| 406 |
if attn_backend == "flex":
|
| 407 |
assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
|
|
@@ -409,8 +409,10 @@ def get_attention_mask(
|
|
| 409 |
if attention_mask is None:
|
| 410 |
flex_block_mask = None
|
| 411 |
else:
|
|
|
|
|
|
|
| 412 |
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
| 413 |
-
return (
|
| 414 |
|
| 415 |
flex_block_mask = create_block_mask(
|
| 416 |
mask_mod,
|
|
@@ -420,12 +422,12 @@ def get_attention_mask(
|
|
| 420 |
seq_len,
|
| 421 |
device=device,
|
| 422 |
)
|
| 423 |
-
|
| 424 |
else:
|
| 425 |
flex_block_mask = None
|
| 426 |
-
|
| 427 |
|
| 428 |
-
return
|
| 429 |
|
| 430 |
|
| 431 |
class ESMplusplusConfig(PretrainedConfig):
|
|
@@ -938,7 +940,7 @@ class TransformerStack(nn.Module):
|
|
| 938 |
attentions = () if output_attentions else None
|
| 939 |
|
| 940 |
# move to 4D attention mask or flex block mask
|
| 941 |
-
|
| 942 |
attn_backend=self._attn_backend,
|
| 943 |
batch_size=x.shape[0],
|
| 944 |
seq_len=x.shape[1],
|
|
@@ -951,14 +953,14 @@ class TransformerStack(nn.Module):
|
|
| 951 |
x, attn_weights = self._gradient_checkpointing_func(
|
| 952 |
block.__call__,
|
| 953 |
x=x,
|
| 954 |
-
attention_mask=
|
| 955 |
flex_block_mask=flex_block_mask,
|
| 956 |
output_attentions=output_attentions,
|
| 957 |
)
|
| 958 |
else:
|
| 959 |
x, attn_weights = block(
|
| 960 |
x=x,
|
| 961 |
-
attention_mask=
|
| 962 |
flex_block_mask=flex_block_mask,
|
| 963 |
output_attentions=output_attentions,
|
| 964 |
)
|
|
|
|
| 399 |
attention_mask: Optional[torch.Tensor] = None
|
| 400 |
) -> torch.Tensor:
|
| 401 |
if attention_mask is None:
|
| 402 |
+
attention_mask_2d = torch.ones((batch_size, seq_len), device=device).bool()
|
| 403 |
else:
|
| 404 |
+
attention_mask_2d = attention_mask.bool()
|
| 405 |
|
| 406 |
if attn_backend == "flex":
|
| 407 |
assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
|
|
|
|
| 409 |
if attention_mask is None:
|
| 410 |
flex_block_mask = None
|
| 411 |
else:
|
| 412 |
+
valid_lens = attention_mask_2d.sum(dim=-1)
|
| 413 |
+
|
| 414 |
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
| 415 |
+
return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])
|
| 416 |
|
| 417 |
flex_block_mask = create_block_mask(
|
| 418 |
mask_mod,
|
|
|
|
| 422 |
seq_len,
|
| 423 |
device=device,
|
| 424 |
)
|
| 425 |
+
attention_mask_4d = None
|
| 426 |
else:
|
| 427 |
flex_block_mask = None
|
| 428 |
+
attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]
|
| 429 |
|
| 430 |
+
return attention_mask_4d, flex_block_mask
|
| 431 |
|
| 432 |
|
| 433 |
class ESMplusplusConfig(PretrainedConfig):
|
|
|
|
| 940 |
attentions = () if output_attentions else None
|
| 941 |
|
| 942 |
# move to 4D attention mask or flex block mask
|
| 943 |
+
attention_mask_4d, flex_block_mask = get_attention_mask(
|
| 944 |
attn_backend=self._attn_backend,
|
| 945 |
batch_size=x.shape[0],
|
| 946 |
seq_len=x.shape[1],
|
|
|
|
| 953 |
x, attn_weights = self._gradient_checkpointing_func(
|
| 954 |
block.__call__,
|
| 955 |
x=x,
|
| 956 |
+
attention_mask=attention_mask_4d,
|
| 957 |
flex_block_mask=flex_block_mask,
|
| 958 |
output_attentions=output_attentions,
|
| 959 |
)
|
| 960 |
else:
|
| 961 |
x, attn_weights = block(
|
| 962 |
x=x,
|
| 963 |
+
attention_mask=attention_mask_4d,
|
| 964 |
flex_block_mask=flex_block_mask,
|
| 965 |
output_attentions=output_attentions,
|
| 966 |
)
|