Upload modeling_fastesm.py with huggingface_hub
Browse files- modeling_fastesm.py +25 -21
modeling_fastesm.py
CHANGED
|
@@ -364,7 +364,6 @@ from typing import Optional, Tuple, Union, Dict, Any
|
|
| 364 |
from einops import rearrange
|
| 365 |
from dataclasses import dataclass
|
| 366 |
from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
|
| 367 |
-
from transformers import initialization as init
|
| 368 |
from transformers.modeling_outputs import (
|
| 369 |
ModelOutput,
|
| 370 |
BaseModelOutputWithPastAndCrossAttentions,
|
|
@@ -399,9 +398,9 @@ def get_attention_mask(
|
|
| 399 |
attention_mask: Optional[torch.Tensor] = None
|
| 400 |
) -> torch.Tensor:
|
| 401 |
if attention_mask is None:
|
| 402 |
-
|
| 403 |
else:
|
| 404 |
-
|
| 405 |
|
| 406 |
if attn_backend == "flex":
|
| 407 |
assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
|
|
@@ -409,8 +408,10 @@ def get_attention_mask(
|
|
| 409 |
if attention_mask is None:
|
| 410 |
flex_block_mask = None
|
| 411 |
else:
|
|
|
|
|
|
|
| 412 |
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
| 413 |
-
return (
|
| 414 |
|
| 415 |
flex_block_mask = create_block_mask(
|
| 416 |
mask_mod,
|
|
@@ -420,12 +421,12 @@ def get_attention_mask(
|
|
| 420 |
seq_len,
|
| 421 |
device=device,
|
| 422 |
)
|
| 423 |
-
|
| 424 |
else:
|
| 425 |
flex_block_mask = None
|
| 426 |
-
|
| 427 |
|
| 428 |
-
return
|
| 429 |
|
| 430 |
|
| 431 |
@dataclass
|
|
@@ -763,16 +764,19 @@ class FastEsmPreTrainedModel(PreTrainedModel):
|
|
| 763 |
return True
|
| 764 |
|
| 765 |
@torch.no_grad()
|
| 766 |
-
def _init_weights(self, module):
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
|
|
|
|
|
|
|
|
|
| 776 |
|
| 777 |
def get_output_embeddings(self):
|
| 778 |
# NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
|
|
@@ -809,7 +813,7 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 809 |
|
| 810 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 811 |
token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
|
| 812 |
-
|
| 813 |
attn_backend=self.config.attn_backend,
|
| 814 |
batch_size=input_ids.shape[0],
|
| 815 |
seq_len=input_ids.shape[1],
|
|
@@ -818,7 +822,7 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 818 |
)
|
| 819 |
encoder_outputs = self.encoder(
|
| 820 |
token_embedding_output,
|
| 821 |
-
attention_mask=
|
| 822 |
flex_block_mask=flex_block_mask,
|
| 823 |
output_hidden_states=False,
|
| 824 |
output_attentions=False,
|
|
@@ -874,7 +878,7 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 874 |
attention_mask=attention_mask,
|
| 875 |
inputs_embeds=inputs_embeds,
|
| 876 |
)
|
| 877 |
-
|
| 878 |
attn_backend=self.config.attn_backend,
|
| 879 |
batch_size=input_ids.shape[0],
|
| 880 |
seq_len=input_ids.shape[1],
|
|
@@ -883,7 +887,7 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 883 |
)
|
| 884 |
encoder_outputs = self.encoder(
|
| 885 |
token_embedding_output,
|
| 886 |
-
attention_mask=
|
| 887 |
flex_block_mask=flex_block_mask,
|
| 888 |
output_hidden_states=output_hidden_states,
|
| 889 |
output_attentions=output_attentions,
|
|
|
|
| 364 |
from einops import rearrange
|
| 365 |
from dataclasses import dataclass
|
| 366 |
from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
|
|
|
|
| 367 |
from transformers.modeling_outputs import (
|
| 368 |
ModelOutput,
|
| 369 |
BaseModelOutputWithPastAndCrossAttentions,
|
|
|
|
| 398 |
attention_mask: Optional[torch.Tensor] = None
|
| 399 |
) -> torch.Tensor:
|
| 400 |
if attention_mask is None:
|
| 401 |
+
attention_mask_2d = torch.ones((batch_size, seq_len), device=device).bool()
|
| 402 |
else:
|
| 403 |
+
attention_mask_2d = attention_mask.bool()
|
| 404 |
|
| 405 |
if attn_backend == "flex":
|
| 406 |
assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
|
|
|
|
| 408 |
if attention_mask is None:
|
| 409 |
flex_block_mask = None
|
| 410 |
else:
|
| 411 |
+
valid_lens = attention_mask_2d.sum(dim=-1)
|
| 412 |
+
|
| 413 |
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
| 414 |
+
return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])
|
| 415 |
|
| 416 |
flex_block_mask = create_block_mask(
|
| 417 |
mask_mod,
|
|
|
|
| 421 |
seq_len,
|
| 422 |
device=device,
|
| 423 |
)
|
| 424 |
+
attention_mask_4d = None
|
| 425 |
else:
|
| 426 |
flex_block_mask = None
|
| 427 |
+
attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]
|
| 428 |
|
| 429 |
+
return attention_mask_4d, flex_block_mask
|
| 430 |
|
| 431 |
|
| 432 |
@dataclass
|
|
|
|
| 764 |
return True
|
| 765 |
|
| 766 |
@torch.no_grad()
|
| 767 |
+
def _init_weights(self, module: nn.Module) -> None:
|
| 768 |
+
std = self.config.initializer_range
|
| 769 |
+
if isinstance(module, nn.Linear):
|
| 770 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 771 |
+
if module.bias is not None:
|
| 772 |
+
module.bias.data.zero_()
|
| 773 |
+
elif isinstance(module, nn.Embedding):
|
| 774 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 775 |
+
if module.padding_idx is not None:
|
| 776 |
+
module.weight.data[module.padding_idx].zero_()
|
| 777 |
+
|
| 778 |
+
def post_init(self) -> None:
|
| 779 |
+
super().post_init()
|
| 780 |
|
| 781 |
def get_output_embeddings(self):
|
| 782 |
# NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
|
|
|
|
| 813 |
|
| 814 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 815 |
token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
|
| 816 |
+
attention_mask_4d, flex_block_mask = get_attention_mask(
|
| 817 |
attn_backend=self.config.attn_backend,
|
| 818 |
batch_size=input_ids.shape[0],
|
| 819 |
seq_len=input_ids.shape[1],
|
|
|
|
| 822 |
)
|
| 823 |
encoder_outputs = self.encoder(
|
| 824 |
token_embedding_output,
|
| 825 |
+
attention_mask=attention_mask_4d,
|
| 826 |
flex_block_mask=flex_block_mask,
|
| 827 |
output_hidden_states=False,
|
| 828 |
output_attentions=False,
|
|
|
|
| 878 |
attention_mask=attention_mask,
|
| 879 |
inputs_embeds=inputs_embeds,
|
| 880 |
)
|
| 881 |
+
attention_mask_4d, flex_block_mask = get_attention_mask(
|
| 882 |
attn_backend=self.config.attn_backend,
|
| 883 |
batch_size=input_ids.shape[0],
|
| 884 |
seq_len=input_ids.shape[1],
|
|
|
|
| 887 |
)
|
| 888 |
encoder_outputs = self.encoder(
|
| 889 |
token_embedding_output,
|
| 890 |
+
attention_mask=attention_mask_4d,
|
| 891 |
flex_block_mask=flex_block_mask,
|
| 892 |
output_hidden_states=output_hidden_states,
|
| 893 |
output_attentions=output_attentions,
|