Align to Mamba2 checkpoint keys and config

Browse files — modeling_genemamba.py (+3, −3)

modeling_genemamba.py
CHANGED
@@ -219,8 +219,8 @@ class GeneMambaModel(GeneMambaPreTrainedModel):

  219          num_hidden_layers=config.num_hidden_layers
  220      )
  221
  222  -    # Final layer normalization
  223  -    self.  [removed line truncated in this extraction]
  224
  225      self.apply(self._init_weights)
  226
@@ -254,7 +254,7 @@ class GeneMambaModel(GeneMambaPreTrainedModel):

  254      hidden_states = self.mamba_mixer(hidden_states, attention_mask)
  255
  256      # Apply final normalization
  257  -    hidden_states = self.  [removed line truncated in this extraction]
  258
  259      # Compute pooled embedding (cell representation)
  260      if self.config.embedding_pooling == "CLS":
|
| 219 |
num_hidden_layers=config.num_hidden_layers
|
| 220 |
)
|
| 221 |
|
| 222 |
+
# Final layer normalization (kept as norm_f to match checkpoint key names)
|
| 223 |
+
self.norm_f = RMSNorm(config.hidden_size)
|
| 224 |
|
| 225 |
self.apply(self._init_weights)
|
| 226 |
|
|
|
|
  254      hidden_states = self.mamba_mixer(hidden_states, attention_mask)
  255
  256      # Apply final normalization
  257  +    hidden_states = self.norm_f(hidden_states)
  258
  259      # Compute pooled embedding (cell representation)
  260      if self.config.embedding_pooling == "CLS":