lhallee committed
Commit bca00ae · verified · 1 Parent(s): a7e2c29

Upload modeling_esm_plusplus.py with huggingface_hub

Files changed (1):
  1. modeling_esm_plusplus.py +6 -47
modeling_esm_plusplus.py CHANGED
@@ -23,32 +23,15 @@ from huggingface_hub import snapshot_download
 from tokenizers import Tokenizer
 from tokenizers.models import BPE
 from tokenizers.processors import TemplateProcessing
+from torch.nn.attention.flex_attention import create_block_mask
+from torch.nn.attention.flex_attention import flex_attention
 from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
 from transformers.modeling_outputs import ModelOutput
 
 from .embedding_mixin import EmbeddingMixin, Pooler
 
-try:
-    from torch.nn.attention.flex_attention import create_block_mask
-    from torch.nn.attention.flex_attention import flex_attention as _raw_flex_attention
-except ImportError:
-    create_block_mask = None
-    _raw_flex_attention = None
-
-
-def _resolve_flex_attention(attn_compile: bool):
-    if _raw_flex_attention is None:
-        return None
-    if not attn_compile:
-        return _raw_flex_attention
-    try:
-        return torch.compile(_raw_flex_attention, dynamic=True)
-    except Exception:
-        return _raw_flex_attention
-
-
-def _create_pad_block_mask(attention_mask_2d: torch.Tensor, block_size: int):
-    assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
+
+def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
     token_valid = attention_mask_2d.bool()
     batch_size, seq_len = token_valid.shape
 
@@ -62,7 +45,6 @@ def _create_pad_block_mask(attention_mask_2d: torch.Tensor, block_size: int):
         seq_len,
         seq_len,
         device=attention_mask_2d.device,
-        BLOCK_SIZE=block_size,
     )
 
 
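For orientation, here is a minimal sketch of what the simplified helper could look like end to end. The diff elides the middle of the function, so pad_mask_mod and the exact argument order are assumptions based on create_block_mask's public signature (mask_mod, B, H, Q_LEN, KV_LEN, device=...):

def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
    token_valid = attention_mask_2d.bool()
    batch_size, seq_len = token_valid.shape

    def pad_mask_mod(b, h, q_idx, kv_idx):
        # Assumed mask_mod: allow a query/key pair only when both
        # tokens are real (non-padding).
        return token_valid[b, q_idx] & token_valid[b, kv_idx]

    return create_block_mask(
        pad_mask_mod,
        batch_size,
        None,  # one mask shared across all heads
        seq_len,
        seq_len,
        device=attention_mask_2d.device,
        # BLOCK_SIZE is no longer passed; the library default (128) applies.
    )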
@@ -89,8 +71,6 @@ class ESMplusplusConfig(PretrainedConfig):
         dropout: float = 0.0,
         initializer_range: float = 0.02,
         attn_backend: str = "flex",
-        attn_compile: bool = True,
-        flex_block_size: int = 128,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -104,8 +84,6 @@
         self.initializer_range = initializer_range
         self.tie_word_embeddings = False
         self.attn_backend = attn_backend
-        self.attn_compile = attn_compile
-        self.flex_block_size = flex_block_size
 
 
 ### Rotary Embeddings
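With attn_compile and flex_block_size gone, the backend choice is the only attention knob left on the config. A small usage sketch (the import path assumes the file is loadable as a local module):

from modeling_esm_plusplus import ESMplusplusConfig

config = ESMplusplusConfig(attn_backend="flex")
# attn_compile / flex_block_size are no longer declared parameters; if passed,
# the constructor would only absorb them via **kwargs. Any value other than
# "flex" presumably routes attention through the SDPA path.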
@@ -321,16 +299,12 @@ class MultiHeadAttention(nn.Module):
         d_model: int,
         n_heads: int,
         attn_backend: str = "flex",
-        attn_compile: bool = True,
-        flex_block_size: int = 128,
     ):
         super().__init__()
         self.d_model = d_model
         self.n_heads = n_heads
         self.d_head = self.d_model // self.n_heads
         self.attn_backend = attn_backend
-        self.flex_block_size = flex_block_size
-        self.flex_attention = _resolve_flex_attention(attn_compile)
         self._warned_flex_fallback = False
         self.layernorm_qkv = nn.Sequential(
             nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
@@ -393,17 +367,15 @@
             sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
         use_flex = (
             self.attn_backend == "flex"
-            and self.flex_attention is not None
             and (attention_mask is None or flex_block_mask is not None)
         )
         if use_flex:
             try:
-                context_BHLD = self.flex_attention(
+                context_BHLD = flex_attention(
                     query_BHLD,
                     key_BHLD,
                     value_BHLD,
                     block_mask=flex_block_mask,
-                    enable_gqa=query_BHLD.shape[1] != key_BHLD.shape[1],
                 )
             except Exception as exc:
                 if not self._warned_flex_fallback:
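The flex path now calls the module-level flex_attention directly: no _resolve_flex_attention compile wrapper, and no enable_gqa flag (the fused QKV projection above produces the same head count for Q and KV, so the flag was always False here). A standalone sketch of the call and the SDPA fallback it guards, with illustrative shapes:

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(2, 4, 16, 8)  # [batch, heads, seq, head_dim], illustrative
k = torch.randn(2, 4, 16, 8)
v = torch.randn(2, 4, 16, 8)

try:
    # block_mask=None gives dense attention; with padding, a BlockMask from
    # _create_pad_block_mask would be passed instead.
    context = flex_attention(q, k, v, block_mask=None)
except Exception:
    # Same kind of fallback the module warns about once and then uses.
    context = F.scaled_dot_product_attention(q, k, v)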
@@ -467,16 +439,12 @@ class UnifiedTransformerBlock(nn.Module):
         expansion_ratio: float = 8 / 3,
         dropout: float = 0.0,
         attn_backend: str = "flex",
-        attn_compile: bool = True,
-        flex_block_size: int = 128,
     ):
         super().__init__()
         self.attn = MultiHeadAttention(
             d_model=d_model,
             n_heads=n_heads,
             attn_backend=attn_backend,
-            attn_compile=attn_compile,
-            flex_block_size=flex_block_size,
         )
         self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
         self.scaling_factor = residue_scaling_factor
@@ -545,12 +513,9 @@ class TransformerStack(nn.Module):
         n_layers: int,
         dropout: float = 0.0,
         attn_backend: str = "flex",
-        attn_compile: bool = True,
-        flex_block_size: int = 128,
     ):
         super().__init__()
         self.attn_backend = attn_backend
-        self.flex_block_size = flex_block_size
         self.blocks = nn.ModuleList(
             [
                 UnifiedTransformerBlock(
@@ -559,8 +524,6 @@
                     residue_scaling_factor=math.sqrt(n_layers / 36),
                     dropout=dropout,
                     attn_backend=attn_backend,
-                    attn_compile=attn_compile,
-                    flex_block_size=flex_block_size,
                 )
                 for i in range(n_layers)
             ]
@@ -591,9 +554,9 @@
 
         if attention_mask is not None:
             attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
-            if self.attn_backend == "flex" and create_block_mask is not None and not output_attentions:
+            if self.attn_backend == "flex" and not output_attentions:
                 token_attention_mask = attention_mask[:, 0, 0, :]
-                flex_block_mask = _create_pad_block_mask(token_attention_mask, self.flex_block_size)
+                flex_block_mask = _create_pad_block_mask(token_attention_mask)
             else:
                 flex_block_mask = None
         else:
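Shape-wise, the plumbing above is a round trip: the 2D padding mask is expanded to 4D for the SDPA path, and the flex branch recovers the per-sequence 1D token mask before building the block mask. A small check with illustrative values:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
batch_size, seq_len = attention_mask.shape

# 4D mask for SDPA: [batch, 1, seq_len, seq_len].
expanded = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()

# The flex branch takes one row back out: [batch, seq_len].
token_attention_mask = expanded[:, 0, 0, :]
assert torch.equal(token_attention_mask, attention_mask.bool())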
@@ -677,8 +640,6 @@ class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
             n_layers=config.num_hidden_layers,
             dropout=config.dropout,
             attn_backend=config.attn_backend,
-            attn_compile=config.attn_compile,
-            flex_block_size=config.flex_block_size,
         )
         self.tokenizer = EsmSequenceTokenizer()
         self.init_weights()
@@ -739,8 +700,6 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
             n_layers=config.num_hidden_layers,
             dropout=config.dropout,
             attn_backend=config.attn_backend,
-            attn_compile=config.attn_compile,
-            flex_block_size=config.flex_block_size,
         )
         self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
         self.ce_loss = nn.CrossEntropyLoss()
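Beyond the dropped config fields, nothing in the public interface changes, so a forward pass should look the same as before. A sketch under that assumption (the layer count, the vocab bound for sampling, and the HF-style forward signature are assumptions, not taken from this diff):

import torch
from modeling_esm_plusplus import ESMplusplusConfig, ESMplusplusForMaskedLM

config = ESMplusplusConfig(num_hidden_layers=2)  # illustrative size
model = ESMplusplusForMaskedLM(config)

input_ids = torch.randint(0, 32, (2, 16))        # assumes vocab_size >= 32
attention_mask = torch.ones(2, 16, dtype=torch.long)
attention_mask[0, 12:] = 0                       # pad the first sequence's tail

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)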