Upload modeling_dplm.py with huggingface_hub
Browse files- modeling_dplm.py +87 -75
modeling_dplm.py
CHANGED
|
@@ -412,22 +412,38 @@ class BaseSequenceTokenizer:
|
|
| 412 |
raise NotImplementedError
|
| 413 |
|
| 414 |
|
| 415 |
-
def
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
|
| 423 |
-
return
|
| 424 |
-
mask_mod,
|
| 425 |
-
batch_size,
|
| 426 |
-
1,
|
| 427 |
-
seq_len,
|
| 428 |
-
seq_len,
|
| 429 |
-
device=attention_mask_2d.device,
|
| 430 |
-
)
|
| 431 |
|
| 432 |
|
| 433 |
@dataclass
|
|
@@ -459,11 +475,20 @@ class DPLMPreTrainedModel(EsmPreTrainedModel):
|
|
| 459 |
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
| 460 |
all_tied_weights_keys = {}
|
| 461 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
|
| 463 |
class ModifiedEsmSelfAttention(EsmSelfAttention):
|
| 464 |
def __init__(self, config, position_embedding_type=None):
|
| 465 |
super().__init__(config, position_embedding_type)
|
| 466 |
-
self.
|
| 467 |
|
| 468 |
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
| 469 |
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
|
@@ -473,7 +498,7 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 473 |
def forward(
|
| 474 |
self,
|
| 475 |
hidden_states: torch.Tensor,
|
| 476 |
-
attention_mask: Optional[torch.
|
| 477 |
head_mask: Optional[torch.FloatTensor] = None,
|
| 478 |
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 479 |
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
@@ -522,24 +547,21 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 522 |
value_layer = value_layer.contiguous()
|
| 523 |
|
| 524 |
if output_attentions:
|
|
|
|
| 525 |
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 526 |
-
|
| 527 |
-
attention_scores = attention_scores + attention_mask
|
| 528 |
attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
|
| 529 |
context_layer = torch.matmul(attention_probs, value_layer)
|
| 530 |
else:
|
| 531 |
attention_probs = None
|
| 532 |
-
if self.attn_backend == "flex":
|
| 533 |
assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
|
| 534 |
assert query_layer.dtype in (torch.float16, torch.bfloat16), (
|
| 535 |
f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
|
| 536 |
)
|
| 537 |
assert is_cross_attention is False, "Flex attention backend currently does not support cross-attention."
|
| 538 |
assert past_key_value is None, "Flex attention backend currently does not support KV caching."
|
| 539 |
-
|
| 540 |
-
assert flex_block_mask is not None, (
|
| 541 |
-
"Flex attention backend requires a block mask when attention_mask is provided."
|
| 542 |
-
)
|
| 543 |
context_layer = flex_attention(
|
| 544 |
query_layer,
|
| 545 |
key_layer,
|
|
@@ -579,14 +601,14 @@ class ModifiedEsmAttention(EsmAttention):
|
|
| 579 |
|
| 580 |
def forward(
|
| 581 |
self,
|
| 582 |
-
hidden_states,
|
| 583 |
-
attention_mask
|
| 584 |
-
head_mask=None,
|
| 585 |
-
encoder_hidden_states=None,
|
| 586 |
-
encoder_attention_mask=None,
|
| 587 |
-
past_key_value=None,
|
| 588 |
-
output_attentions=False,
|
| 589 |
-
flex_block_mask=None,
|
| 590 |
):
|
| 591 |
hidden_states_ln = self.LayerNorm(hidden_states)
|
| 592 |
self_outputs = self.self(
|
|
@@ -622,14 +644,14 @@ class ModifiedEsmLayer(EsmLayer):
|
|
| 622 |
|
| 623 |
def forward(
|
| 624 |
self,
|
| 625 |
-
hidden_states,
|
| 626 |
-
attention_mask
|
| 627 |
-
head_mask=None,
|
| 628 |
-
encoder_hidden_states=None,
|
| 629 |
-
encoder_attention_mask=None,
|
| 630 |
-
past_key_value=None,
|
| 631 |
-
output_attentions=False,
|
| 632 |
-
flex_block_mask=None,
|
| 633 |
):
|
| 634 |
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
| 635 |
self_attention_outputs = self.attention(
|
|
@@ -688,17 +710,17 @@ class ModifiedEsmEncoder(EsmEncoder):
|
|
| 688 |
|
| 689 |
def forward(
|
| 690 |
self,
|
| 691 |
-
hidden_states,
|
| 692 |
-
attention_mask
|
| 693 |
-
head_mask=None,
|
| 694 |
-
encoder_hidden_states=None,
|
| 695 |
-
encoder_attention_mask=None,
|
| 696 |
-
past_key_values=None,
|
| 697 |
-
use_cache=None,
|
| 698 |
-
output_attentions=False,
|
| 699 |
-
output_hidden_states=False,
|
| 700 |
-
return_dict=True,
|
| 701 |
-
flex_block_mask=None,
|
| 702 |
):
|
| 703 |
all_hidden_states = () if output_hidden_states else None
|
| 704 |
all_self_attentions = () if output_attentions else None
|
|
@@ -873,22 +895,12 @@ class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 873 |
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 874 |
|
| 875 |
if attention_mask is None:
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
token_attention_mask = None
|
| 879 |
-
if attention_mask.dim() == 2:
|
| 880 |
token_attention_mask = attention_mask.bool()
|
| 881 |
-
if self.config.attn_backend == "flex" and output_attentions is False:
|
| 882 |
-
extended_attention_mask = None
|
| 883 |
-
else:
|
| 884 |
-
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
| 885 |
elif attention_mask.dim() == 4:
|
| 886 |
-
|
| 887 |
-
|
| 888 |
-
else:
|
| 889 |
-
extended_attention_mask = attention_mask
|
| 890 |
-
if input_ids is not None:
|
| 891 |
-
token_attention_mask = input_ids.ne(self.config.pad_token_id)
|
| 892 |
else:
|
| 893 |
raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
|
| 894 |
|
|
@@ -907,16 +919,16 @@ class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 907 |
if embedding_attention_mask is None and input_ids is not None:
|
| 908 |
embedding_attention_mask = input_ids.ne(self.config.pad_token_id)
|
| 909 |
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
|
| 921 |
embedding_output = self.embeddings(
|
| 922 |
input_ids=input_ids,
|
|
|
|
| 412 |
raise NotImplementedError
|
| 413 |
|
| 414 |
|
| 415 |
+
def get_attention_mask(
    attn_backend: str,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[object]]:
    """Build the attention mask object required by the selected backend.

    A (query, key) pair is attendable only when both the query token and the
    key token are unmasked in the token-level mask.

    Args:
        attn_backend: Backend name. ``"flex"`` produces a flex block mask;
            any other value produces a dense boolean mask.
        batch_size: Number of sequences in the batch.
        seq_len: Sequence length used for the default mask and for the block
            mask dimensions.
        device: Device for the default mask and the block mask.
        attention_mask: Optional 2-D ``(batch, seq)`` token mask; non-zero
            entries mark tokens to keep. When ``None`` every token is kept.

    Returns:
        ``(extended_attention_mask, flex_block_mask)``. For ``"flex"`` the
        first element is ``None`` and the second is the object returned by
        ``create_block_mask``. Otherwise the first element is a boolean
        ``(batch, 1, seq, seq)`` tensor and the second is ``None``.

    Raises:
        ValueError: If ``attention_mask`` is given but is not 2-D.
        AssertionError: If ``"flex"`` is requested while ``create_block_mask``
            is unavailable (``None``).
    """
    if attention_mask is None:
        # Allocate booleans directly instead of creating floats and casting.
        token_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
    else:
        if attention_mask.dim() != 2:
            # The indexing below assumes exactly (batch, seq); any other rank
            # would fail obscurely or broadcast into a wrong-rank mask.
            raise ValueError(
                f"get_attention_mask expects a 2D (batch, seq) attention_mask, got shape {tuple(attention_mask.shape)}"
            )
        token_attention_mask = attention_mask.bool()

    if attn_backend == "flex":
        # Raised explicitly (not via ``assert``) so the check is not stripped
        # under ``python -O``; the exception type stays AssertionError.
        if create_block_mask is None:
            raise AssertionError("Flex attention backend requested but torch.create_block_mask is unavailable.")

        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
            # head_idx is unused: the same mask applies to every head.
            return token_attention_mask[batch_idx, q_idx] & token_attention_mask[batch_idx, kv_idx]

        flex_block_mask = create_block_mask(
            mask_mod,
            batch_size,
            1,
            seq_len,
            seq_len,
            device=device,
        )
        return None, flex_block_mask

    # Dense path: outer AND of the token mask over query and key positions,
    # with a singleton head dimension so it broadcasts across heads.
    query_keep = token_attention_mask[:, None, :, None]
    key_keep = token_attention_mask[:, None, None, :]
    return query_keep & key_keep, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
|
| 448 |
|
| 449 |
@dataclass
|
|
|
|
| 475 |
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
| 476 |
all_tied_weights_keys = {}
|
| 477 |
|
| 478 |
+
@property
def attn_backend(self) -> str:
    """Return the attention backend name stored on ``self.config``."""
    return self.config.attn_backend

@attn_backend.setter
def attn_backend(self, backend: str) -> None:
    """Validate ``backend`` and store it on ``self.config``.

    Raises:
        AssertionError: If ``backend`` is not ``"sdpa"`` or ``"flex"``.
    """
    # Raised explicitly rather than via ``assert`` so the validation is not
    # stripped under ``python -O``; the exception type is unchanged.
    if backend not in ("sdpa", "flex"):
        raise AssertionError(f"Unsupported attn_backend: {backend}")
    self.config.attn_backend = backend
|
| 486 |
+
|
| 487 |
|
| 488 |
class ModifiedEsmSelfAttention(EsmSelfAttention):
|
| 489 |
def __init__(self, config, position_embedding_type=None):
|
| 490 |
super().__init__(config, position_embedding_type)
|
| 491 |
+
self.config = config
|
| 492 |
|
| 493 |
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
| 494 |
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
|
|
|
| 498 |
def forward(
|
| 499 |
self,
|
| 500 |
hidden_states: torch.Tensor,
|
| 501 |
+
attention_mask: Optional[torch.Tensor],
|
| 502 |
head_mask: Optional[torch.FloatTensor] = None,
|
| 503 |
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 504 |
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
| 547 |
value_layer = value_layer.contiguous()
|
| 548 |
|
| 549 |
if output_attentions:
|
| 550 |
+
assert attention_mask is not None, "output_attentions=True requires a concrete attention mask."
|
| 551 |
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 552 |
+
attention_scores = attention_scores.masked_fill(attention_mask.logical_not(), float("-inf"))
|
|
|
|
| 553 |
attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
|
| 554 |
context_layer = torch.matmul(attention_probs, value_layer)
|
| 555 |
else:
|
| 556 |
attention_probs = None
|
| 557 |
+
if self.config.attn_backend == "flex":
|
| 558 |
assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
|
| 559 |
assert query_layer.dtype in (torch.float16, torch.bfloat16), (
|
| 560 |
f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
|
| 561 |
)
|
| 562 |
assert is_cross_attention is False, "Flex attention backend currently does not support cross-attention."
|
| 563 |
assert past_key_value is None, "Flex attention backend currently does not support KV caching."
|
| 564 |
+
assert flex_block_mask is not None, "Flex attention backend requires a block mask."
|
|
|
|
|
|
|
|
|
|
| 565 |
context_layer = flex_attention(
|
| 566 |
query_layer,
|
| 567 |
key_layer,
|
|
|
|
| 601 |
|
| 602 |
def forward(
|
| 603 |
self,
|
| 604 |
+
hidden_states: torch.Tensor,
|
| 605 |
+
attention_mask: Optional[torch.Tensor],
|
| 606 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 607 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 608 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 609 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 610 |
+
output_attentions: bool = False,
|
| 611 |
+
flex_block_mask: Optional[object] = None,
|
| 612 |
):
|
| 613 |
hidden_states_ln = self.LayerNorm(hidden_states)
|
| 614 |
self_outputs = self.self(
|
|
|
|
| 644 |
|
| 645 |
def forward(
|
| 646 |
self,
|
| 647 |
+
hidden_states: torch.Tensor,
|
| 648 |
+
attention_mask: Optional[torch.Tensor],
|
| 649 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 650 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 651 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 652 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 653 |
+
output_attentions: bool = False,
|
| 654 |
+
flex_block_mask: Optional[object] = None,
|
| 655 |
):
|
| 656 |
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
| 657 |
self_attention_outputs = self.attention(
|
|
|
|
| 710 |
|
| 711 |
def forward(
|
| 712 |
self,
|
| 713 |
+
hidden_states: torch.Tensor,
|
| 714 |
+
attention_mask: Optional[torch.Tensor],
|
| 715 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 716 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 717 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 718 |
+
past_key_values: Optional[List[Tuple[Tuple[torch.FloatTensor]]]] = None,
|
| 719 |
+
use_cache: Optional[bool] = None,
|
| 720 |
+
output_attentions: bool = False,
|
| 721 |
+
output_hidden_states: bool = False,
|
| 722 |
+
return_dict: bool = True,
|
| 723 |
+
flex_block_mask: Optional[object] = None,
|
| 724 |
):
|
| 725 |
all_hidden_states = () if output_hidden_states else None
|
| 726 |
all_self_attentions = () if output_attentions else None
|
|
|
|
| 895 |
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 896 |
|
| 897 |
if attention_mask is None:
|
| 898 |
+
token_attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device).bool()
|
| 899 |
+
elif attention_mask.dim() == 2:
|
|
|
|
|
|
|
| 900 |
token_attention_mask = attention_mask.bool()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 901 |
elif attention_mask.dim() == 4:
|
| 902 |
+
assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask."
|
| 903 |
+
token_attention_mask = input_ids.ne(self.config.pad_token_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 904 |
else:
|
| 905 |
raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
|
| 906 |
|
|
|
|
| 919 |
if embedding_attention_mask is None and input_ids is not None:
|
| 920 |
embedding_attention_mask = input_ids.ne(self.config.pad_token_id)
|
| 921 |
|
| 922 |
+
if self.config.attn_backend == "flex" and output_attentions:
|
| 923 |
+
raise AssertionError("output_attentions=True is not supported with attn_backend='flex'.")
|
| 924 |
+
|
| 925 |
+
extended_attention_mask, flex_block_mask = get_attention_mask(
|
| 926 |
+
attn_backend=self.config.attn_backend,
|
| 927 |
+
batch_size=batch_size,
|
| 928 |
+
seq_len=seq_length,
|
| 929 |
+
device=device,
|
| 930 |
+
attention_mask=token_attention_mask,
|
| 931 |
+
)
|
| 932 |
|
| 933 |
embedding_output = self.embeddings(
|
| 934 |
input_ids=input_ids,
|