lhallee committed on
Commit
94a1af4
·
verified ·
1 Parent(s): d0455eb

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +32 -31
modeling_esm_plusplus.py CHANGED
@@ -36,8 +36,13 @@ try:
36
  # when used from AutoModel, these are in the same directory
37
  from .embedding_mixin import EmbeddingMixin, Pooler
38
  except:
39
- # when running from our repo, these are in the base directory
40
- from embedding_mixin import EmbeddingMixin, Pooler
 
 
 
 
 
41
 
42
 
43
  def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
@@ -384,35 +389,27 @@ class MultiHeadAttention(nn.Module):
384
  attn_weights = F.softmax(attn_weights, dim=-1)
385
  context_BHLD = torch.matmul(attn_weights, value_BHLD)
386
  else:
387
- sdpa_mask = None
388
- if attention_mask is not None:
389
- sdpa_mask = torch.zeros_like(attention_mask, dtype=query_BHLD.dtype)
390
- sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
391
- flex_dtype_supported = query_BHLD.dtype in (torch.float16, torch.bfloat16)
392
- use_flex = (
393
- self.attn_backend == "flex"
394
- and flex_attention is not None
395
- and flex_dtype_supported
396
- and (attention_mask is None or flex_block_mask is not None)
397
- )
398
- if use_flex:
399
- try:
400
- context_BHLD = flex_attention(
401
- query_BHLD,
402
- key_BHLD,
403
- value_BHLD,
404
- block_mask=flex_block_mask,
405
- scale=scale,
406
- )
407
- except Exception:
408
- context_BHLD = F.scaled_dot_product_attention(
409
- query_BHLD,
410
- key_BHLD,
411
- value_BHLD,
412
- attn_mask=sdpa_mask,
413
- scale=scale,
414
  )
 
 
 
 
 
 
 
415
  else:
 
 
 
 
416
  context_BHLD = F.scaled_dot_product_attention(
417
  query_BHLD,
418
  key_BHLD,
@@ -577,11 +574,15 @@ class TransformerStack(nn.Module):
577
  if attention_mask is not None:
578
  assert attention_mask.ndim == 2, f"Expected 2D token attention mask, got shape {attention_mask.shape}."
579
  token_attention_mask = attention_mask.bool()
580
- pairwise_attention_mask = token_attention_mask.unsqueeze(-1) & token_attention_mask.unsqueeze(-2)
581
- attention_mask = pairwise_attention_mask.unsqueeze(1)
582
  if self.attn_backend == "flex" and not output_attentions:
 
 
 
583
  flex_block_mask = _create_pad_block_mask(token_attention_mask)
 
584
  else:
 
 
585
  flex_block_mask = None
586
  else:
587
  flex_block_mask = None
 
36
  # when used from AutoModel, these are in the same directory
37
  from .embedding_mixin import EmbeddingMixin, Pooler
38
  except:
39
+ try:
40
+ # when importing as a submodule, embedding mixin is in the FastPLMs directory
41
+ from ..embedding_mixin import EmbeddingMixin, Pooler
42
+ except:
43
+ # when running from our repo, these are in the base directory
44
+ from embedding_mixin import EmbeddingMixin, Pooler
45
+
46
 
47
 
48
  def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
 
389
  attn_weights = F.softmax(attn_weights, dim=-1)
390
  context_BHLD = torch.matmul(attn_weights, value_BHLD)
391
  else:
392
+ if self.attn_backend == "flex":
393
+ assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
394
+ assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
395
+ f"Flex attention backend requires float16 or bfloat16, got {query_BHLD.dtype}."
396
+ )
397
+ if attention_mask is not None:
398
+ assert flex_block_mask is not None, (
399
+ "Flex attention backend requires a block mask when attention_mask is provided."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  )
401
+ context_BHLD = flex_attention(
402
+ query_BHLD,
403
+ key_BHLD,
404
+ value_BHLD,
405
+ block_mask=flex_block_mask,
406
+ scale=scale,
407
+ )
408
  else:
409
+ sdpa_mask = None
410
+ if attention_mask is not None:
411
+ sdpa_mask = torch.zeros_like(attention_mask, dtype=query_BHLD.dtype)
412
+ sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
413
  context_BHLD = F.scaled_dot_product_attention(
414
  query_BHLD,
415
  key_BHLD,
 
574
  if attention_mask is not None:
575
  assert attention_mask.ndim == 2, f"Expected 2D token attention mask, got shape {attention_mask.shape}."
576
  token_attention_mask = attention_mask.bool()
 
 
577
  if self.attn_backend == "flex" and not output_attentions:
578
+ assert create_block_mask is not None, (
579
+ "Flex attention backend requested but torch.create_block_mask is unavailable."
580
+ )
581
  flex_block_mask = _create_pad_block_mask(token_attention_mask)
582
+ attention_mask = None
583
  else:
584
+ pairwise_attention_mask = token_attention_mask.unsqueeze(-1) & token_attention_mask.unsqueeze(-2)
585
+ attention_mask = pairwise_attention_mask.unsqueeze(1)
586
  flex_block_mask = None
587
  else:
588
  flex_block_mask = None