Upload modeling_esm_plusplus.py with huggingface_hub
Browse files — modeling_esm_plusplus.py (+4 / −4 lines)
modeling_esm_plusplus.py
CHANGED
|
@@ -70,7 +70,7 @@ class ESMplusplusConfig(PretrainedConfig):
|
|
| 70 |
problem_type: str | None = None,
|
| 71 |
dropout: float = 0.0,
|
| 72 |
initializer_range: float = 0.02,
|
| 73 |
-
attn_backend: str = "
|
| 74 |
**kwargs,
|
| 75 |
):
|
| 76 |
super().__init__(**kwargs)
|
|
@@ -298,7 +298,7 @@ class MultiHeadAttention(nn.Module):
|
|
| 298 |
self,
|
| 299 |
d_model: int,
|
| 300 |
n_heads: int,
|
| 301 |
-
attn_backend: str = "
|
| 302 |
):
|
| 303 |
super().__init__()
|
| 304 |
self.d_model = d_model
|
|
@@ -438,7 +438,7 @@ class UnifiedTransformerBlock(nn.Module):
|
|
| 438 |
residue_scaling_factor: float = 1,
|
| 439 |
expansion_ratio: float = 8 / 3,
|
| 440 |
dropout: float = 0.0,
|
| 441 |
-
attn_backend: str = "
|
| 442 |
):
|
| 443 |
super().__init__()
|
| 444 |
self.attn = MultiHeadAttention(
|
|
@@ -512,7 +512,7 @@ class TransformerStack(nn.Module):
|
|
| 512 |
n_heads: int,
|
| 513 |
n_layers: int,
|
| 514 |
dropout: float = 0.0,
|
| 515 |
-
attn_backend: str = "
|
| 516 |
):
|
| 517 |
super().__init__()
|
| 518 |
self.attn_backend = attn_backend
|
|
|
|
| 70 |
problem_type: str | None = None,
|
| 71 |
dropout: float = 0.0,
|
| 72 |
initializer_range: float = 0.02,
|
| 73 |
+
attn_backend: str = "sdpa",
|
| 74 |
**kwargs,
|
| 75 |
):
|
| 76 |
super().__init__(**kwargs)
|
|
|
|
| 298 |
self,
|
| 299 |
d_model: int,
|
| 300 |
n_heads: int,
|
| 301 |
+
attn_backend: str = "sdpa",
|
| 302 |
):
|
| 303 |
super().__init__()
|
| 304 |
self.d_model = d_model
|
|
|
|
| 438 |
residue_scaling_factor: float = 1,
|
| 439 |
expansion_ratio: float = 8 / 3,
|
| 440 |
dropout: float = 0.0,
|
| 441 |
+
attn_backend: str = "sdpa",
|
| 442 |
):
|
| 443 |
super().__init__()
|
| 444 |
self.attn = MultiHeadAttention(
|
|
|
|
| 512 |
n_heads: int,
|
| 513 |
n_layers: int,
|
| 514 |
dropout: float = 0.0,
|
| 515 |
+
attn_backend: str = "sdpa",
|
| 516 |
):
|
| 517 |
super().__init__()
|
| 518 |
self.attn_backend = attn_backend
|