Update modeling_neollm.py
Browse files- modeling_neollm.py +5 -1
modeling_neollm.py
CHANGED
|
@@ -1210,10 +1210,14 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1210 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1211 |
|
| 1212 |
self.post_init()
|
| 1213 |
-
def tie_weights(self):
|
| 1214 |
"""
|
| 1215 |
Tie the weights between the input embeddings and the output embeddings.
|
| 1216 |
Required for v5.0 compatibility.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1217 |
"""
|
| 1218 |
self._tie_or_clone_weights(self.lm_head, self.model.embed_tokens)
|
| 1219 |
def forward(
|
|
|
|
| 1210 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1211 |
|
| 1212 |
self.post_init()
|
| 1213 |
+
def tie_weights(self, missing_keys=None, recompute_mapping=True):
|
| 1214 |
"""
|
| 1215 |
Tie the weights between the input embeddings and the output embeddings.
|
| 1216 |
Required for v5.0 compatibility.
|
| 1217 |
+
|
| 1218 |
+
Args:
|
| 1219 |
+
missing_keys: List of missing keys from checkpoint loading (v5.0 parameter)
|
| 1220 |
+
recompute_mapping: Whether to recompute tied weights mapping (v5.0 parameter)
|
| 1221 |
"""
|
| 1222 |
self._tie_or_clone_weights(self.lm_head, self.model.embed_tokens)
|
| 1223 |
def forward(
|