Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +6 -47
modeling_esm_plusplus.py
CHANGED
|
@@ -23,32 +23,15 @@ from huggingface_hub import snapshot_download
|
|
| 23 |
from tokenizers import Tokenizer
|
| 24 |
from tokenizers.models import BPE
|
| 25 |
from tokenizers.processors import TemplateProcessing
|
|
|
|
|
|
|
| 26 |
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
|
| 27 |
from transformers.modeling_outputs import ModelOutput
|
| 28 |
|
| 29 |
from .embedding_mixin import EmbeddingMixin, Pooler
|
| 30 |
|
| 31 |
-
try:
|
| 32 |
-
from torch.nn.attention.flex_attention import create_block_mask
|
| 33 |
-
from torch.nn.attention.flex_attention import flex_attention as _raw_flex_attention
|
| 34 |
-
except ImportError:
|
| 35 |
-
create_block_mask = None
|
| 36 |
-
_raw_flex_attention = None
|
| 37 |
|
| 38 |
-
|
| 39 |
-
def _resolve_flex_attention(attn_compile: bool):
|
| 40 |
-
if _raw_flex_attention is None:
|
| 41 |
-
return None
|
| 42 |
-
if not attn_compile:
|
| 43 |
-
return _raw_flex_attention
|
| 44 |
-
try:
|
| 45 |
-
return torch.compile(_raw_flex_attention, dynamic=True)
|
| 46 |
-
except Exception:
|
| 47 |
-
return _raw_flex_attention
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def _create_pad_block_mask(attention_mask_2d: torch.Tensor, block_size: int):
|
| 51 |
-
assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
|
| 52 |
token_valid = attention_mask_2d.bool()
|
| 53 |
batch_size, seq_len = token_valid.shape
|
| 54 |
|
|
@@ -62,7 +45,6 @@ def _create_pad_block_mask(attention_mask_2d: torch.Tensor, block_size: int):
|
|
| 62 |
seq_len,
|
| 63 |
seq_len,
|
| 64 |
device=attention_mask_2d.device,
|
| 65 |
-
BLOCK_SIZE=block_size,
|
| 66 |
)
|
| 67 |
|
| 68 |
|
|
@@ -89,8 +71,6 @@ class ESMplusplusConfig(PretrainedConfig):
|
|
| 89 |
dropout: float = 0.0,
|
| 90 |
initializer_range: float = 0.02,
|
| 91 |
attn_backend: str = "flex",
|
| 92 |
-
attn_compile: bool = True,
|
| 93 |
-
flex_block_size: int = 128,
|
| 94 |
**kwargs,
|
| 95 |
):
|
| 96 |
super().__init__(**kwargs)
|
|
@@ -104,8 +84,6 @@ class ESMplusplusConfig(PretrainedConfig):
|
|
| 104 |
self.initializer_range = initializer_range
|
| 105 |
self.tie_word_embeddings = False
|
| 106 |
self.attn_backend = attn_backend
|
| 107 |
-
self.attn_compile = attn_compile
|
| 108 |
-
self.flex_block_size = flex_block_size
|
| 109 |
|
| 110 |
|
| 111 |
### Rotary Embeddings
|
|
@@ -321,16 +299,12 @@ class MultiHeadAttention(nn.Module):
|
|
| 321 |
d_model: int,
|
| 322 |
n_heads: int,
|
| 323 |
attn_backend: str = "flex",
|
| 324 |
-
attn_compile: bool = True,
|
| 325 |
-
flex_block_size: int = 128,
|
| 326 |
):
|
| 327 |
super().__init__()
|
| 328 |
self.d_model = d_model
|
| 329 |
self.n_heads = n_heads
|
| 330 |
self.d_head = self.d_model // self.n_heads
|
| 331 |
self.attn_backend = attn_backend
|
| 332 |
-
self.flex_block_size = flex_block_size
|
| 333 |
-
self.flex_attention = _resolve_flex_attention(attn_compile)
|
| 334 |
self._warned_flex_fallback = False
|
| 335 |
self.layernorm_qkv = nn.Sequential(
|
| 336 |
nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
|
|
@@ -393,17 +367,15 @@ class MultiHeadAttention(nn.Module):
|
|
| 393 |
sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
| 394 |
use_flex = (
|
| 395 |
self.attn_backend == "flex"
|
| 396 |
-
and self.flex_attention is not None
|
| 397 |
and (attention_mask is None or flex_block_mask is not None)
|
| 398 |
)
|
| 399 |
if use_flex:
|
| 400 |
try:
|
| 401 |
-
context_BHLD =
|
| 402 |
query_BHLD,
|
| 403 |
key_BHLD,
|
| 404 |
value_BHLD,
|
| 405 |
block_mask=flex_block_mask,
|
| 406 |
-
enable_gqa=query_BHLD.shape[1] != key_BHLD.shape[1],
|
| 407 |
)
|
| 408 |
except Exception as exc:
|
| 409 |
if not self._warned_flex_fallback:
|
|
@@ -467,16 +439,12 @@ class UnifiedTransformerBlock(nn.Module):
|
|
| 467 |
expansion_ratio: float = 8 / 3,
|
| 468 |
dropout: float = 0.0,
|
| 469 |
attn_backend: str = "flex",
|
| 470 |
-
attn_compile: bool = True,
|
| 471 |
-
flex_block_size: int = 128,
|
| 472 |
):
|
| 473 |
super().__init__()
|
| 474 |
self.attn = MultiHeadAttention(
|
| 475 |
d_model=d_model,
|
| 476 |
n_heads=n_heads,
|
| 477 |
attn_backend=attn_backend,
|
| 478 |
-
attn_compile=attn_compile,
|
| 479 |
-
flex_block_size=flex_block_size,
|
| 480 |
)
|
| 481 |
self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
|
| 482 |
self.scaling_factor = residue_scaling_factor
|
|
@@ -545,12 +513,9 @@ class TransformerStack(nn.Module):
|
|
| 545 |
n_layers: int,
|
| 546 |
dropout: float = 0.0,
|
| 547 |
attn_backend: str = "flex",
|
| 548 |
-
attn_compile: bool = True,
|
| 549 |
-
flex_block_size: int = 128,
|
| 550 |
):
|
| 551 |
super().__init__()
|
| 552 |
self.attn_backend = attn_backend
|
| 553 |
-
self.flex_block_size = flex_block_size
|
| 554 |
self.blocks = nn.ModuleList(
|
| 555 |
[
|
| 556 |
UnifiedTransformerBlock(
|
|
@@ -559,8 +524,6 @@ class TransformerStack(nn.Module):
|
|
| 559 |
residue_scaling_factor=math.sqrt(n_layers / 36),
|
| 560 |
dropout=dropout,
|
| 561 |
attn_backend=attn_backend,
|
| 562 |
-
attn_compile=attn_compile,
|
| 563 |
-
flex_block_size=flex_block_size,
|
| 564 |
)
|
| 565 |
for i in range(n_layers)
|
| 566 |
]
|
|
@@ -591,9 +554,9 @@ class TransformerStack(nn.Module):
|
|
| 591 |
|
| 592 |
if attention_mask is not None:
|
| 593 |
attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
|
| 594 |
-
if self.attn_backend == "flex" and
|
| 595 |
token_attention_mask = attention_mask[:, 0, 0, :]
|
| 596 |
-
flex_block_mask = _create_pad_block_mask(token_attention_mask
|
| 597 |
else:
|
| 598 |
flex_block_mask = None
|
| 599 |
else:
|
|
@@ -677,8 +640,6 @@ class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
|
|
| 677 |
n_layers=config.num_hidden_layers,
|
| 678 |
dropout=config.dropout,
|
| 679 |
attn_backend=config.attn_backend,
|
| 680 |
-
attn_compile=config.attn_compile,
|
| 681 |
-
flex_block_size=config.flex_block_size,
|
| 682 |
)
|
| 683 |
self.tokenizer = EsmSequenceTokenizer()
|
| 684 |
self.init_weights()
|
|
@@ -739,8 +700,6 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
|
|
| 739 |
n_layers=config.num_hidden_layers,
|
| 740 |
dropout=config.dropout,
|
| 741 |
attn_backend=config.attn_backend,
|
| 742 |
-
attn_compile=config.attn_compile,
|
| 743 |
-
flex_block_size=config.flex_block_size,
|
| 744 |
)
|
| 745 |
self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
|
| 746 |
self.ce_loss = nn.CrossEntropyLoss()
|
|
|
|
| 23 |
from tokenizers import Tokenizer
|
| 24 |
from tokenizers.models import BPE
|
| 25 |
from tokenizers.processors import TemplateProcessing
|
| 26 |
+
from torch.nn.attention.flex_attention import create_block_mask
|
| 27 |
+
from torch.nn.attention.flex_attention import flex_attention
|
| 28 |
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
|
| 29 |
from transformers.modeling_outputs import ModelOutput
|
| 30 |
|
| 31 |
from .embedding_mixin import EmbeddingMixin, Pooler
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
+
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
token_valid = attention_mask_2d.bool()
|
| 36 |
batch_size, seq_len = token_valid.shape
|
| 37 |
|
|
|
|
| 45 |
seq_len,
|
| 46 |
seq_len,
|
| 47 |
device=attention_mask_2d.device,
|
|
|
|
| 48 |
)
|
| 49 |
|
| 50 |
|
|
|
|
| 71 |
dropout: float = 0.0,
|
| 72 |
initializer_range: float = 0.02,
|
| 73 |
attn_backend: str = "flex",
|
|
|
|
|
|
|
| 74 |
**kwargs,
|
| 75 |
):
|
| 76 |
super().__init__(**kwargs)
|
|
|
|
| 84 |
self.initializer_range = initializer_range
|
| 85 |
self.tie_word_embeddings = False
|
| 86 |
self.attn_backend = attn_backend
|
|
|
|
|
|
|
| 87 |
|
| 88 |
|
| 89 |
### Rotary Embeddings
|
|
|
|
| 299 |
d_model: int,
|
| 300 |
n_heads: int,
|
| 301 |
attn_backend: str = "flex",
|
|
|
|
|
|
|
| 302 |
):
|
| 303 |
super().__init__()
|
| 304 |
self.d_model = d_model
|
| 305 |
self.n_heads = n_heads
|
| 306 |
self.d_head = self.d_model // self.n_heads
|
| 307 |
self.attn_backend = attn_backend
|
|
|
|
|
|
|
| 308 |
self._warned_flex_fallback = False
|
| 309 |
self.layernorm_qkv = nn.Sequential(
|
| 310 |
nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
|
|
|
|
| 367 |
sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
| 368 |
use_flex = (
|
| 369 |
self.attn_backend == "flex"
|
|
|
|
| 370 |
and (attention_mask is None or flex_block_mask is not None)
|
| 371 |
)
|
| 372 |
if use_flex:
|
| 373 |
try:
|
| 374 |
+
context_BHLD = flex_attention(
|
| 375 |
query_BHLD,
|
| 376 |
key_BHLD,
|
| 377 |
value_BHLD,
|
| 378 |
block_mask=flex_block_mask,
|
|
|
|
| 379 |
)
|
| 380 |
except Exception as exc:
|
| 381 |
if not self._warned_flex_fallback:
|
|
|
|
| 439 |
expansion_ratio: float = 8 / 3,
|
| 440 |
dropout: float = 0.0,
|
| 441 |
attn_backend: str = "flex",
|
|
|
|
|
|
|
| 442 |
):
|
| 443 |
super().__init__()
|
| 444 |
self.attn = MultiHeadAttention(
|
| 445 |
d_model=d_model,
|
| 446 |
n_heads=n_heads,
|
| 447 |
attn_backend=attn_backend,
|
|
|
|
|
|
|
| 448 |
)
|
| 449 |
self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
|
| 450 |
self.scaling_factor = residue_scaling_factor
|
|
|
|
| 513 |
n_layers: int,
|
| 514 |
dropout: float = 0.0,
|
| 515 |
attn_backend: str = "flex",
|
|
|
|
|
|
|
| 516 |
):
|
| 517 |
super().__init__()
|
| 518 |
self.attn_backend = attn_backend
|
|
|
|
| 519 |
self.blocks = nn.ModuleList(
|
| 520 |
[
|
| 521 |
UnifiedTransformerBlock(
|
|
|
|
| 524 |
residue_scaling_factor=math.sqrt(n_layers / 36),
|
| 525 |
dropout=dropout,
|
| 526 |
attn_backend=attn_backend,
|
|
|
|
|
|
|
| 527 |
)
|
| 528 |
for i in range(n_layers)
|
| 529 |
]
|
|
|
|
| 554 |
|
| 555 |
if attention_mask is not None:
|
| 556 |
attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
|
| 557 |
+
if self.attn_backend == "flex" and not output_attentions:
|
| 558 |
token_attention_mask = attention_mask[:, 0, 0, :]
|
| 559 |
+
flex_block_mask = _create_pad_block_mask(token_attention_mask)
|
| 560 |
else:
|
| 561 |
flex_block_mask = None
|
| 562 |
else:
|
|
|
|
| 640 |
n_layers=config.num_hidden_layers,
|
| 641 |
dropout=config.dropout,
|
| 642 |
attn_backend=config.attn_backend,
|
|
|
|
|
|
|
| 643 |
)
|
| 644 |
self.tokenizer = EsmSequenceTokenizer()
|
| 645 |
self.init_weights()
|
|
|
|
| 700 |
n_layers=config.num_hidden_layers,
|
| 701 |
dropout=config.dropout,
|
| 702 |
attn_backend=config.attn_backend,
|
|
|
|
|
|
|
| 703 |
)
|
| 704 |
self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
|
| 705 |
self.ce_loss = nn.CrossEntropyLoss()
|