Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +6 -6
modeling_esm_plusplus.py
CHANGED
|
@@ -249,7 +249,7 @@ class SwiGLU(nn.Module):
|
|
| 249 |
return F.silu(x1) * x2
|
| 250 |
|
| 251 |
|
| 252 |
-
def swiglu_ln_ffn(d_model: int, expansion_ratio: float
|
| 253 |
"""Create SwiGLU feedforward network with layer normalization."""
|
| 254 |
return nn.Sequential(
|
| 255 |
nn.LayerNorm(d_model),
|
|
@@ -257,7 +257,6 @@ def swiglu_ln_ffn(d_model: int, expansion_ratio: float, dropout: float = 0.0) ->
|
|
| 257 |
d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
|
| 258 |
),
|
| 259 |
SwiGLU(),
|
| 260 |
-
nn.Dropout(dropout),
|
| 261 |
nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
|
| 262 |
)
|
| 263 |
|
|
@@ -377,8 +376,9 @@ class UnifiedTransformerBlock(nn.Module):
|
|
| 377 |
):
|
| 378 |
super().__init__()
|
| 379 |
self.attn = MultiHeadAttention(d_model, n_heads)
|
| 380 |
-
self.ffn = swiglu_ln_ffn(d_model, expansion_ratio
|
| 381 |
self.scaling_factor = residue_scaling_factor
|
|
|
|
| 382 |
|
| 383 |
def forward(
|
| 384 |
self,
|
|
@@ -396,9 +396,8 @@ class UnifiedTransformerBlock(nn.Module):
|
|
| 396 |
Output tensor after transformer block, and optionally attention weights
|
| 397 |
"""
|
| 398 |
attn_output, attn_weights = self.attn(x, attention_mask, output_attentions)
|
| 399 |
-
x = x + attn_output / self.scaling_factor
|
| 400 |
-
|
| 401 |
-
x = x + r3
|
| 402 |
if output_attentions:
|
| 403 |
return x, attn_weights
|
| 404 |
return x
|
|
@@ -431,6 +430,7 @@ class TransformerStack(nn.Module):
|
|
| 431 |
d_model: Model dimension
|
| 432 |
n_heads: Number of attention heads
|
| 433 |
n_layers: Number of transformer layers
|
|
|
|
| 434 |
"""
|
| 435 |
def __init__(
|
| 436 |
self,
|
|
|
|
| 249 |
return F.silu(x1) * x2
|
| 250 |
|
| 251 |
|
| 252 |
+
def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
|
| 253 |
"""Create SwiGLU feedforward network with layer normalization."""
|
| 254 |
return nn.Sequential(
|
| 255 |
nn.LayerNorm(d_model),
|
|
|
|
| 257 |
d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
|
| 258 |
),
|
| 259 |
SwiGLU(),
|
|
|
|
| 260 |
nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
|
| 261 |
)
|
| 262 |
|
|
|
|
| 376 |
):
|
| 377 |
super().__init__()
|
| 378 |
self.attn = MultiHeadAttention(d_model, n_heads)
|
| 379 |
+
self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
|
| 380 |
self.scaling_factor = residue_scaling_factor
|
| 381 |
+
self.dropout = nn.Dropout(dropout)
|
| 382 |
|
| 383 |
def forward(
|
| 384 |
self,
|
|
|
|
| 396 |
Output tensor after transformer block, and optionally attention weights
|
| 397 |
"""
|
| 398 |
attn_output, attn_weights = self.attn(x, attention_mask, output_attentions)
|
| 399 |
+
x = x + self.dropout(attn_output) / self.scaling_factor
|
| 400 |
+
x = x + self.dropout(self.ffn(x)) / self.scaling_factor
|
|
|
|
| 401 |
if output_attentions:
|
| 402 |
return x, attn_weights
|
| 403 |
return x
|
|
|
|
| 430 |
d_model: Model dimension
|
| 431 |
n_heads: Number of attention heads
|
| 432 |
n_layers: Number of transformer layers
|
| 433 |
+
dropout: Dropout rate
|
| 434 |
"""
|
| 435 |
def __init__(
|
| 436 |
self,
|