lhallee committed on
Commit
94df558
·
verified ·
1 Parent(s): 0050733

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +47 -36
modeling_esm_plusplus.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import torch._inductor.config as inductor_config
3
  import torch._dynamo as dynamo
@@ -429,7 +431,8 @@ Contains: AttentionBackend enum, backend resolution, mask creation,
429
  flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
430
  """
431
  from enum import Enum
432
- from typing import Optional
 
433
 
434
  import torch
435
  import torch.nn as nn
@@ -447,7 +450,12 @@ _compiled_flex_attention = None
447
 
448
 
449
  def _get_flex_attention_fn():
450
- """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set."""
 
 
 
 
 
451
  global _compiled_flex_attention
452
  if flex_attention is None:
453
  return None
@@ -455,12 +463,15 @@ def _get_flex_attention_fn():
455
  if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
456
  return flex_attention
457
  if _compiled_flex_attention is None:
458
- _compiled_flex_attention = torch.compile(flex_attention)
 
 
 
459
  return _compiled_flex_attention
460
 
461
 
462
  ### Kernels Flash Attention Detection
463
- def _infer_kernels_flash_variant(kernel) -> str | None:
464
  if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
465
  return "flash_attn2"
466
  if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
@@ -576,7 +587,7 @@ class IndexFirstAxis(torch.autograd.Function):
576
  ).reshape(-1, *other_shape)
577
 
578
  @staticmethod
579
- def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
580
  (indices,) = ctx.saved_tensors
581
  assert grad_output.ndim >= 2
582
  other_shape = grad_output.shape[1:]
@@ -599,7 +610,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
599
  return output
600
 
601
  @staticmethod
602
- def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:
603
  (indices,) = ctx.saved_tensors
604
  return grad_output[indices], None, None
605
 
@@ -618,7 +629,7 @@ def _unpad_input(
618
  key_layer: torch.Tensor,
619
  value_layer: torch.Tensor,
620
  attention_mask_2d: torch.Tensor,
621
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
622
  batch_size, seq_len, num_heads, head_dim = query_layer.shape
623
  seqlens = attention_mask_2d.sum(dim=1).int()
624
  cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
@@ -634,7 +645,7 @@ def kernels_flash_attention_func(
634
  query_states: torch.Tensor,
635
  key_states: torch.Tensor,
636
  value_states: torch.Tensor,
637
- attention_mask_2d: torch.Tensor | None = None,
638
  causal: bool = False,
639
  ) -> torch.Tensor:
640
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
@@ -707,7 +718,7 @@ def get_attention_mask(
707
  seq_len: int,
708
  device: torch.device,
709
  attention_mask: Optional[torch.Tensor] = None,
710
- ) -> tuple[torch.Tensor | None, torch.Tensor | None, "BlockMask | None"]:
711
  """Build padding masks once for all encoder layers.
712
 
713
  Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
@@ -783,7 +794,7 @@ class ESMplusplusConfig(PretrainedConfig):
783
  num_attention_heads: int = 15,
784
  num_hidden_layers: int = 30,
785
  num_labels: int = 2,
786
- problem_type: str | None = None,
787
  dropout: float = 0.0,
788
  initializer_range: float = 0.02,
789
  attn_backend: str = "sdpa",
@@ -1057,12 +1068,12 @@ class MultiHeadAttention(nn.Module):
1057
  def forward(
1058
  self,
1059
  x: torch.Tensor,
1060
- attention_mask_2d: torch.Tensor | None = None,
1061
- attention_mask_4d: torch.Tensor | None = None,
1062
- flex_block_mask: "BlockMask | None" = None,
1063
  output_attentions: bool = False,
1064
  output_s_max: bool = False,
1065
- ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
1066
  qkv_BLD3 = self.layernorm_qkv(x)
1067
  query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
1068
  query_BLD, key_BLD = (
@@ -1089,12 +1100,12 @@ class MultiHeadAttention(nn.Module):
1089
  query_BHLD: torch.Tensor,
1090
  key_BHLD: torch.Tensor,
1091
  value_BHLD: torch.Tensor,
1092
- attention_mask_2d: torch.Tensor | None = None,
1093
- attention_mask_4d: torch.Tensor | None = None,
1094
- flex_block_mask: "BlockMask | None" = None,
1095
  output_attentions: bool = False,
1096
  output_s_max: bool = False,
1097
- ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
1098
  if output_attentions:
1099
  return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
1100
 
@@ -1111,7 +1122,7 @@ class MultiHeadAttention(nn.Module):
1111
  return attn_output, attn_weights, s_max
1112
 
1113
  @torch.no_grad()
1114
- def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> list[torch.Tensor]:
1115
  q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
1116
  k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
1117
  s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values * self.scale
@@ -1122,9 +1133,9 @@ class MultiHeadAttention(nn.Module):
1122
  query_BHLD: torch.Tensor,
1123
  key_BHLD: torch.Tensor,
1124
  value_BHLD: torch.Tensor,
1125
- attention_mask_4d: torch.Tensor | None = None,
1126
  output_s_max: bool = False,
1127
- ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
1128
  attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * self.scale
1129
  if attention_mask_4d is not None:
1130
  attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
@@ -1139,8 +1150,8 @@ class MultiHeadAttention(nn.Module):
1139
  query_BHLD: torch.Tensor,
1140
  key_BHLD: torch.Tensor,
1141
  value_BHLD: torch.Tensor,
1142
- attention_mask_2d: torch.Tensor | None = None,
1143
- ) -> tuple[torch.Tensor, None]:
1144
  query_BLHD = query_BHLD.transpose(1, 2).contiguous()
1145
  key_BLHD = key_BHLD.transpose(1, 2).contiguous()
1146
  value_BLHD = value_BHLD.transpose(1, 2).contiguous()
@@ -1155,8 +1166,8 @@ class MultiHeadAttention(nn.Module):
1155
  query_BHLD: torch.Tensor,
1156
  key_BHLD: torch.Tensor,
1157
  value_BHLD: torch.Tensor,
1158
- flex_block_mask: "BlockMask | None" = None,
1159
- ) -> tuple[torch.Tensor, None]:
1160
  assert flex_attention is not None, "Flex attention is not available in this environment."
1161
  fn = _get_flex_attention_fn()
1162
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=self.scale)
@@ -1167,8 +1178,8 @@ class MultiHeadAttention(nn.Module):
1167
  query_BHLD: torch.Tensor,
1168
  key_BHLD: torch.Tensor,
1169
  value_BHLD: torch.Tensor,
1170
- attention_mask_4d: torch.Tensor | None = None,
1171
- ) -> tuple[torch.Tensor, None]:
1172
  context_BHLD = F.scaled_dot_product_attention(
1173
  query_BHLD, key_BHLD, value_BHLD, attn_mask=attention_mask_4d, scale=self.scale,
1174
  )
@@ -1214,12 +1225,12 @@ class UnifiedTransformerBlock(nn.Module):
1214
  def forward(
1215
  self,
1216
  x: torch.Tensor,
1217
- attention_mask_2d: torch.Tensor | None = None,
1218
- attention_mask_4d: torch.Tensor | None = None,
1219
- flex_block_mask: "BlockMask | None" = None,
1220
  output_attentions: bool = False,
1221
  output_s_max: bool = False,
1222
- ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
1223
  attn_output, attn_weights, s_max = self.attn(
1224
  x,
1225
  attention_mask_2d=attention_mask_2d,
@@ -1240,7 +1251,7 @@ class TransformerOutput(ModelOutput):
1240
  last_hidden_state: Optional[torch.Tensor] = None
1241
  hidden_states: Optional[Tuple[torch.Tensor]] = None
1242
  attentions: Optional[Tuple[torch.Tensor]] = None
1243
- s_max: Optional[Tuple[list[torch.Tensor], ...]] = None
1244
 
1245
 
1246
  @dataclass
@@ -1251,7 +1262,7 @@ class ESMplusplusOutput(ModelOutput):
1251
  last_hidden_state: Optional[torch.Tensor] = None
1252
  hidden_states: Optional[Tuple[torch.Tensor]] = None
1253
  attentions: Optional[Tuple[torch.Tensor]] = None
1254
- s_max: Optional[Tuple[list[torch.Tensor], ...]] = None
1255
 
1256
 
1257
  ### Transformer Stack
@@ -1772,7 +1783,7 @@ def get_esmc_checkpoint_path(model: str) -> Path:
1772
  def _load_esmc_checkpoint_model(
1773
  config: ESMplusplusConfig,
1774
  model: str,
1775
- device: torch.device | str = "cpu",
1776
  ) -> ESMplusplusForMaskedLM:
1777
  key = _resolve_esmc_checkpoint_key(model)
1778
  spec = _ESMC_CHECKPOINT_SPECS[key]
@@ -1795,7 +1806,7 @@ def _load_esmc_checkpoint_model(
1795
  return model_obj
1796
 
1797
 
1798
- def ESMplusplus_300M(device: torch.device | str = "cpu"):
1799
  config = ESMplusplusConfig(
1800
  hidden_size=960,
1801
  num_attention_heads=15,
@@ -1804,7 +1815,7 @@ def ESMplusplus_300M(device: torch.device | str = "cpu"):
1804
  return _load_esmc_checkpoint_model(config=config, model="esmc-300", device=device)
1805
 
1806
 
1807
- def ESMplusplus_600M(device: torch.device | str = "cpu"):
1808
  config = ESMplusplusConfig(
1809
  hidden_size=1152,
1810
  num_attention_heads=18,
 
1
+ from __future__ import annotations
2
+
3
  import torch
4
  import torch._inductor.config as inductor_config
5
  import torch._dynamo as dynamo
 
431
  flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
432
  """
433
  from enum import Enum
434
+ from functools import partial
435
+ from typing import Dict, List, Optional, Tuple
436
 
437
  import torch
438
  import torch.nn as nn
 
450
 
451
 
452
  def _get_flex_attention_fn():
453
+ """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
454
+
455
+ Uses kernel_options={"BACKEND": "FLASH"} to prefer Flash Attention 4 (FA4)
456
+ on Hopper/Blackwell GPUs (PyTorch 2.11+). Automatically falls back to Triton
457
+ on older hardware.
458
+ """
459
  global _compiled_flex_attention
460
  if flex_attention is None:
461
  return None
 
463
  if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
464
  return flex_attention
465
  if _compiled_flex_attention is None:
466
+ _compiled_flex_attention = torch.compile(
467
+ partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
468
+ dynamic=False,
469
+ )
470
  return _compiled_flex_attention
471
 
472
 
473
  ### Kernels Flash Attention Detection
474
+ def _infer_kernels_flash_variant(kernel) -> Optional[str]:
475
  if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
476
  return "flash_attn2"
477
  if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
 
587
  ).reshape(-1, *other_shape)
588
 
589
  @staticmethod
590
+ def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
591
  (indices,) = ctx.saved_tensors
592
  assert grad_output.ndim >= 2
593
  other_shape = grad_output.shape[1:]
 
610
  return output
611
 
612
  @staticmethod
613
+ def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
614
  (indices,) = ctx.saved_tensors
615
  return grad_output[indices], None, None
616
 
 
629
  key_layer: torch.Tensor,
630
  value_layer: torch.Tensor,
631
  attention_mask_2d: torch.Tensor,
632
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
633
  batch_size, seq_len, num_heads, head_dim = query_layer.shape
634
  seqlens = attention_mask_2d.sum(dim=1).int()
635
  cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
 
645
  query_states: torch.Tensor,
646
  key_states: torch.Tensor,
647
  value_states: torch.Tensor,
648
+ attention_mask_2d: Optional[torch.Tensor] = None,
649
  causal: bool = False,
650
  ) -> torch.Tensor:
651
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
 
718
  seq_len: int,
719
  device: torch.device,
720
  attention_mask: Optional[torch.Tensor] = None,
721
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
722
  """Build padding masks once for all encoder layers.
723
 
724
  Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
 
794
  num_attention_heads: int = 15,
795
  num_hidden_layers: int = 30,
796
  num_labels: int = 2,
797
+ problem_type: Optional[str] = None,
798
  dropout: float = 0.0,
799
  initializer_range: float = 0.02,
800
  attn_backend: str = "sdpa",
 
1068
  def forward(
1069
  self,
1070
  x: torch.Tensor,
1071
+ attention_mask_2d: Optional[torch.Tensor] = None,
1072
+ attention_mask_4d: Optional[torch.Tensor] = None,
1073
+ flex_block_mask: Optional[BlockMask] = None,
1074
  output_attentions: bool = False,
1075
  output_s_max: bool = False,
1076
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
1077
  qkv_BLD3 = self.layernorm_qkv(x)
1078
  query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
1079
  query_BLD, key_BLD = (
 
1100
  query_BHLD: torch.Tensor,
1101
  key_BHLD: torch.Tensor,
1102
  value_BHLD: torch.Tensor,
1103
+ attention_mask_2d: Optional[torch.Tensor] = None,
1104
+ attention_mask_4d: Optional[torch.Tensor] = None,
1105
+ flex_block_mask: Optional[BlockMask] = None,
1106
  output_attentions: bool = False,
1107
  output_s_max: bool = False,
1108
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
1109
  if output_attentions:
1110
  return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
1111
 
 
1122
  return attn_output, attn_weights, s_max
1123
 
1124
  @torch.no_grad()
1125
+ def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> List[torch.Tensor]:
1126
  q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
1127
  k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
1128
  s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values * self.scale
 
1133
  query_BHLD: torch.Tensor,
1134
  key_BHLD: torch.Tensor,
1135
  value_BHLD: torch.Tensor,
1136
+ attention_mask_4d: Optional[torch.Tensor] = None,
1137
  output_s_max: bool = False,
1138
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[List[torch.Tensor]]]:
1139
  attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * self.scale
1140
  if attention_mask_4d is not None:
1141
  attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
 
1150
  query_BHLD: torch.Tensor,
1151
  key_BHLD: torch.Tensor,
1152
  value_BHLD: torch.Tensor,
1153
+ attention_mask_2d: Optional[torch.Tensor] = None,
1154
+ ) -> Tuple[torch.Tensor, None]:
1155
  query_BLHD = query_BHLD.transpose(1, 2).contiguous()
1156
  key_BLHD = key_BHLD.transpose(1, 2).contiguous()
1157
  value_BLHD = value_BHLD.transpose(1, 2).contiguous()
 
1166
  query_BHLD: torch.Tensor,
1167
  key_BHLD: torch.Tensor,
1168
  value_BHLD: torch.Tensor,
1169
+ flex_block_mask: Optional[BlockMask] = None,
1170
+ ) -> Tuple[torch.Tensor, None]:
1171
  assert flex_attention is not None, "Flex attention is not available in this environment."
1172
  fn = _get_flex_attention_fn()
1173
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=self.scale)
 
1178
  query_BHLD: torch.Tensor,
1179
  key_BHLD: torch.Tensor,
1180
  value_BHLD: torch.Tensor,
1181
+ attention_mask_4d: Optional[torch.Tensor] = None,
1182
+ ) -> Tuple[torch.Tensor, None]:
1183
  context_BHLD = F.scaled_dot_product_attention(
1184
  query_BHLD, key_BHLD, value_BHLD, attn_mask=attention_mask_4d, scale=self.scale,
1185
  )
 
1225
  def forward(
1226
  self,
1227
  x: torch.Tensor,
1228
+ attention_mask_2d: Optional[torch.Tensor] = None,
1229
+ attention_mask_4d: Optional[torch.Tensor] = None,
1230
+ flex_block_mask: Optional[BlockMask] = None,
1231
  output_attentions: bool = False,
1232
  output_s_max: bool = False,
1233
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
1234
  attn_output, attn_weights, s_max = self.attn(
1235
  x,
1236
  attention_mask_2d=attention_mask_2d,
 
1251
  last_hidden_state: Optional[torch.Tensor] = None
1252
  hidden_states: Optional[Tuple[torch.Tensor]] = None
1253
  attentions: Optional[Tuple[torch.Tensor]] = None
1254
+ s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
1255
 
1256
 
1257
  @dataclass
 
1262
  last_hidden_state: Optional[torch.Tensor] = None
1263
  hidden_states: Optional[Tuple[torch.Tensor]] = None
1264
  attentions: Optional[Tuple[torch.Tensor]] = None
1265
+ s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
1266
 
1267
 
1268
  ### Transformer Stack
 
1783
  def _load_esmc_checkpoint_model(
1784
  config: ESMplusplusConfig,
1785
  model: str,
1786
+ device: Union[torch.device, str] = "cpu",
1787
  ) -> ESMplusplusForMaskedLM:
1788
  key = _resolve_esmc_checkpoint_key(model)
1789
  spec = _ESMC_CHECKPOINT_SPECS[key]
 
1806
  return model_obj
1807
 
1808
 
1809
+ def ESMplusplus_300M(device: Union[torch.device, str] = "cpu"):
1810
  config = ESMplusplusConfig(
1811
  hidden_size=960,
1812
  num_attention_heads=15,
 
1815
  return _load_esmc_checkpoint_model(config=config, model="esmc-300", device=device)
1816
 
1817
 
1818
+ def ESMplusplus_600M(device: Union[torch.device, str] = "cpu"):
1819
  config = ESMplusplusConfig(
1820
  hidden_size=1152,
1821
  num_attention_heads=18,