mineself2016 committed on
Commit
020c027
·
verified ·
1 Parent(s): df0709f

Align to Mamba2 checkpoint keys and config

Browse files
Files changed (1) hide show
  1. modeling_genemamba.py +3 -3
modeling_genemamba.py CHANGED
@@ -219,8 +219,8 @@ class GeneMambaModel(GeneMambaPreTrainedModel):
219
  num_hidden_layers=config.num_hidden_layers
220
  )
221
 
222
- # Final layer normalization
223
- self.norm = RMSNorm(config.hidden_size)
224
 
225
  self.apply(self._init_weights)
226
 
@@ -254,7 +254,7 @@ class GeneMambaModel(GeneMambaPreTrainedModel):
254
  hidden_states = self.mamba_mixer(hidden_states, attention_mask)
255
 
256
  # Apply final normalization
257
- hidden_states = self.norm(hidden_states)
258
 
259
  # Compute pooled embedding (cell representation)
260
  if self.config.embedding_pooling == "CLS":
 
219
  num_hidden_layers=config.num_hidden_layers
220
  )
221
 
222
+ # Final layer normalization (kept as norm_f to match checkpoint key names)
223
+ self.norm_f = RMSNorm(config.hidden_size)
224
 
225
  self.apply(self._init_weights)
226
 
 
254
  hidden_states = self.mamba_mixer(hidden_states, attention_mask)
255
 
256
  # Apply final normalization
257
+ hidden_states = self.norm_f(hidden_states)
258
 
259
  # Compute pooled embedding (cell representation)
260
  if self.config.embedding_pooling == "CLS":