Update modeling_rcps.py
Browse files
Enable `prenorm=False` for RCPSAddNormWrapper, which prevents returning the residual
- modeling_rcps.py +4 -3
modeling_rcps.py
CHANGED
|
@@ -101,11 +101,12 @@ class RCPSAddNormWrapper(RCPSWrapper):
|
|
| 101 |
def __init__(self, submodule: nn.Module):
|
| 102 |
super().__init__(submodule)
|
| 103 |
|
| 104 |
-
def forward(self, x, residual=None):
|
| 105 |
"""
|
| 106 |
Args:
|
| 107 |
x: Input tensor of shape (batch_size, seq_len, channels)
|
| 108 |
residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
|
|
|
|
| 109 |
"""
|
| 110 |
n_channels = x.shape[-1]
|
| 111 |
if residual is None:
|
|
@@ -123,7 +124,7 @@ class RCPSAddNormWrapper(RCPSWrapper):
|
|
| 123 |
residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1)
|
| 124 |
x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
|
| 125 |
|
| 126 |
-
return x, residual
|
| 127 |
|
| 128 |
|
| 129 |
class RCPSMambaBlock(nn.Module):
|
|
@@ -159,7 +160,7 @@ class RCPSMambaBlock(nn.Module):
|
|
| 159 |
inference_params: inference parameters for mixer.
|
| 160 |
"""
|
| 161 |
if not self.fused_add_norm:
|
| 162 |
-
hidden_states, residual = self.norm(hidden_states, residual=residual)
|
| 163 |
if self.residual_in_fp32:
|
| 164 |
residual = residual.to(torch.float32)
|
| 165 |
else:
|
|
|
|
| 101 |
def __init__(self, submodule: nn.Module):
|
| 102 |
super().__init__(submodule)
|
| 103 |
|
| 104 |
+
def forward(self, x, residual=None, prenorm=True):
|
| 105 |
"""
|
| 106 |
Args:
|
| 107 |
x: Input tensor of shape (batch_size, seq_len, channels)
|
| 108 |
residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
|
| 109 |
+
prenorm: Whether to return residual.
|
| 110 |
"""
|
| 111 |
n_channels = x.shape[-1]
|
| 112 |
if residual is None:
|
|
|
|
| 124 |
residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1)
|
| 125 |
x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
|
| 126 |
|
| 127 |
+
return x if not prenorm else (x, residual)
|
| 128 |
|
| 129 |
|
| 130 |
class RCPSMambaBlock(nn.Module):
|
|
|
|
| 160 |
inference_params: inference parameters for mixer.
|
| 161 |
"""
|
| 162 |
if not self.fused_add_norm:
|
| 163 |
+
hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True)
|
| 164 |
if self.residual_in_fp32:
|
| 165 |
residual = residual.to(torch.float32)
|
| 166 |
else:
|