Update modeling_rcps.py
Browse files- modeling_rcps.py +5 -0
modeling_rcps.py
CHANGED
|
@@ -148,6 +148,11 @@ class RCPSMambaBlock(nn.Module):
|
|
| 148 |
self.mixer = RCPSWrapper(mixer_cls(dim))
|
| 149 |
norm_f = norm_cls(dim)
|
| 150 |
self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
def forward(
|
| 153 |
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
|
|
|
|
| 148 |
self.mixer = RCPSWrapper(mixer_cls(dim))
|
| 149 |
norm_f = norm_cls(dim)
|
| 150 |
self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
|
| 151 |
+
if self.fused_add_norm:
|
| 152 |
+
assert RMSNorm is not None, "RMSNorm import fails"
|
| 153 |
+
assert isinstance(
|
| 154 |
+
self.norm, (nn.LayerNorm, RMSNorm)
|
| 155 |
+
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
|
| 156 |
|
| 157 |
def forward(
|
| 158 |
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
|