lhallee committed on
Commit
faef8c4
·
verified ·
1 Parent(s): 2133ca6

Upload modeling_fastesm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +11 -2
modeling_fastesm.py CHANGED
@@ -508,7 +508,7 @@ class EsmSelfAttention(nn.Module):
508
  self.scale = self.attention_head_size**-0.5
509
 
510
  self.dropout_prob = config.attention_probs_dropout_prob
511
- self.attn_backend = config.attn_backend
512
  self.position_embedding_type = position_embedding_type or getattr(
513
  config, "position_embedding_type", "absolute"
514
  )
@@ -555,7 +555,7 @@ class EsmSelfAttention(nn.Module):
555
  context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
556
  return context_layer, attention_probs
557
  else:
558
- if self.attn_backend == "flex":
559
  assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
560
  assert query_layer.dtype in (torch.float16, torch.bfloat16), f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
561
  assert flex_block_mask is not None, "Flex attention backend requires a block mask"
@@ -771,6 +771,15 @@ class FastEsmPreTrainedModel(PreTrainedModel):
771
  # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
772
  return None
773
 
 
 
 
 
 
 
 
 
 
774
 
775
  class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
776
  def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
 
508
  self.scale = self.attention_head_size**-0.5
509
 
510
  self.dropout_prob = config.attention_probs_dropout_prob
511
+ self.config = config
512
  self.position_embedding_type = position_embedding_type or getattr(
513
  config, "position_embedding_type", "absolute"
514
  )
 
555
  context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
556
  return context_layer, attention_probs
557
  else:
558
+ if self.config.attn_backend == "flex":
559
  assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
560
  assert query_layer.dtype in (torch.float16, torch.bfloat16), f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
561
  assert flex_block_mask is not None, "Flex attention backend requires a block mask"
 
771
  # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
772
  return None
773
 
774
@property
def attn_backend(self) -> str:
    """Return the attention backend name ("sdpa" or "flex") stored on the model config."""
    return self.config.attn_backend

@attn_backend.setter
def attn_backend(self, backend: str) -> None:
    """Switch the attention backend on the model config.

    Args:
        backend: Backend identifier; only "sdpa" and "flex" are supported.

    Raises:
        ValueError: If *backend* is not one of the supported names.
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would silently accept an unsupported backend string.
    if backend not in ("sdpa", "flex"):
        raise ValueError(f"Unsupported attn_backend: {backend}")
    self.config.attn_backend = backend
782
+
783
 
784
  class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
785
  def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):