lhallee committed on
Commit
97fb267
·
verified ·
1 Parent(s): d5eccc0

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +9 -17
modeling_esm_plusplus.py CHANGED
@@ -11,7 +11,6 @@ License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-licen
11
  import entrypoint_setup
12
  import math
13
  import os
14
- import warnings
15
  import torch
16
  import torch.nn as nn
17
  import torch.nn.functional as F
@@ -24,10 +23,14 @@ from huggingface_hub import snapshot_download
24
  from tokenizers import Tokenizer
25
  from tokenizers.models import BPE
26
  from tokenizers.processors import TemplateProcessing
27
- from torch.nn.attention.flex_attention import create_block_mask
28
- from torch.nn.attention.flex_attention import flex_attention
29
  from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
30
  from transformers.modeling_outputs import ModelOutput
 
 
 
 
 
 
31
 
32
  try:
33
  # when used from AutoModel, these are in the same directory
@@ -38,6 +41,7 @@ except:
38
 
39
 
40
  def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
 
41
  token_valid = attention_mask_2d.bool()
42
  batch_size, seq_len = token_valid.shape
43
 
@@ -325,7 +329,6 @@ class MultiHeadAttention(nn.Module):
325
  self.n_heads = n_heads
326
  self.d_head = self.d_model // self.n_heads
327
  self.attn_backend = attn_backend
328
- self._warned_flex_fallback = False
329
  self.layernorm_qkv = nn.Sequential(
330
  nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
331
  )
@@ -388,15 +391,10 @@ class MultiHeadAttention(nn.Module):
388
  flex_dtype_supported = query_BHLD.dtype in (torch.float16, torch.bfloat16)
389
  use_flex = (
390
  self.attn_backend == "flex"
 
391
  and flex_dtype_supported
392
  and (attention_mask is None or flex_block_mask is not None)
393
  )
394
- if self.attn_backend == "flex" and not flex_dtype_supported and not self._warned_flex_fallback:
395
- warnings.warn(
396
- "Flex attention backend requested in float32; falling back to SDPA for strict numerical parity.",
397
- RuntimeWarning,
398
- )
399
- self._warned_flex_fallback = True
400
  if use_flex:
401
  try:
402
  context_BHLD = flex_attention(
@@ -406,13 +404,7 @@ class MultiHeadAttention(nn.Module):
406
  block_mask=flex_block_mask,
407
  scale=scale,
408
  )
409
- except Exception as exc:
410
- if not self._warned_flex_fallback:
411
- warnings.warn(
412
- f"Flex attention failed in ESM++ attention; falling back to SDPA. Error: {exc}",
413
- RuntimeWarning,
414
- )
415
- self._warned_flex_fallback = True
416
  context_BHLD = F.scaled_dot_product_attention(
417
  query_BHLD,
418
  key_BHLD,
 
11
  import entrypoint_setup
12
  import math
13
  import os
 
14
  import torch
15
  import torch.nn as nn
16
  import torch.nn.functional as F
 
23
  from tokenizers import Tokenizer
24
  from tokenizers.models import BPE
25
  from tokenizers.processors import TemplateProcessing
 
 
26
  from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
27
  from transformers.modeling_outputs import ModelOutput
28
+ try:
29
+ from torch.nn.attention.flex_attention import create_block_mask
30
+ from torch.nn.attention.flex_attention import flex_attention
31
+ except ImportError:
32
+ create_block_mask = None
33
+ flex_attention = None
34
 
35
  try:
36
  # when used from AutoModel, these are in the same directory
 
41
 
42
 
43
  def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
44
+ assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
45
  token_valid = attention_mask_2d.bool()
46
  batch_size, seq_len = token_valid.shape
47
 
 
329
  self.n_heads = n_heads
330
  self.d_head = self.d_model // self.n_heads
331
  self.attn_backend = attn_backend
 
332
  self.layernorm_qkv = nn.Sequential(
333
  nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
334
  )
 
391
  flex_dtype_supported = query_BHLD.dtype in (torch.float16, torch.bfloat16)
392
  use_flex = (
393
  self.attn_backend == "flex"
394
+ and flex_attention is not None
395
  and flex_dtype_supported
396
  and (attention_mask is None or flex_block_mask is not None)
397
  )
 
 
 
 
 
 
398
  if use_flex:
399
  try:
400
  context_BHLD = flex_attention(
 
404
  block_mask=flex_block_mask,
405
  scale=scale,
406
  )
407
+ except Exception:
 
 
 
 
 
 
408
  context_BHLD = F.scaled_dot_product_attention(
409
  query_BHLD,
410
  key_BHLD,