lhallee committed on
Commit
3dc87a1
·
verified ·
1 Parent(s): 5d6dbf2

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +4 -4
modeling_esm_plusplus.py CHANGED
@@ -70,7 +70,7 @@ class ESMplusplusConfig(PretrainedConfig):
70
  problem_type: str | None = None,
71
  dropout: float = 0.0,
72
  initializer_range: float = 0.02,
73
- attn_backend: str = "flex",
74
  **kwargs,
75
  ):
76
  super().__init__(**kwargs)
@@ -298,7 +298,7 @@ class MultiHeadAttention(nn.Module):
298
  self,
299
  d_model: int,
300
  n_heads: int,
301
- attn_backend: str = "flex",
302
  ):
303
  super().__init__()
304
  self.d_model = d_model
@@ -438,7 +438,7 @@ class UnifiedTransformerBlock(nn.Module):
438
  residue_scaling_factor: float = 1,
439
  expansion_ratio: float = 8 / 3,
440
  dropout: float = 0.0,
441
- attn_backend: str = "flex",
442
  ):
443
  super().__init__()
444
  self.attn = MultiHeadAttention(
@@ -512,7 +512,7 @@ class TransformerStack(nn.Module):
512
  n_heads: int,
513
  n_layers: int,
514
  dropout: float = 0.0,
515
- attn_backend: str = "flex",
516
  ):
517
  super().__init__()
518
  self.attn_backend = attn_backend
 
70
  problem_type: str | None = None,
71
  dropout: float = 0.0,
72
  initializer_range: float = 0.02,
73
+ attn_backend: str = "sdpa",
74
  **kwargs,
75
  ):
76
  super().__init__(**kwargs)
 
298
  self,
299
  d_model: int,
300
  n_heads: int,
301
+ attn_backend: str = "sdpa",
302
  ):
303
  super().__init__()
304
  self.d_model = d_model
 
438
  residue_scaling_factor: float = 1,
439
  expansion_ratio: float = 8 / 3,
440
  dropout: float = 0.0,
441
+ attn_backend: str = "sdpa",
442
  ):
443
  super().__init__()
444
  self.attn = MultiHeadAttention(
 
512
  n_heads: int,
513
  n_layers: int,
514
  dropout: float = 0.0,
515
+ attn_backend: str = "sdpa",
516
  ):
517
  super().__init__()
518
  self.attn_backend = attn_backend