Update modeling_gpt_refact.py
Browse files- modeling_gpt_refact.py +12 -0
modeling_gpt_refact.py
CHANGED
|
@@ -369,6 +369,12 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
|
|
| 369 |
# Initialize weights and apply final processing
|
| 370 |
self.post_init()
|
| 371 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
def forward(
|
| 373 |
self,
|
| 374 |
input_ids: Optional[torch.Tensor] = None,
|
|
@@ -509,6 +515,12 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
|
|
| 509 |
# Initialize weights and apply final processing
|
| 510 |
self.post_init()
|
| 511 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
| 513 |
if inputs_embeds is not None and past_key_values is None:
|
| 514 |
model_inputs = {"inputs_embeds": inputs_embeds}
|
|
|
|
| 369 |
# Initialize weights and apply final processing
|
| 370 |
self.post_init()
|
| 371 |
|
| 372 |
+
def get_input_embeddings(self):
|
| 373 |
+
return self.wte
|
| 374 |
+
|
| 375 |
+
def set_input_embeddings(self, new_embeddings):
|
| 376 |
+
self.wte = new_embeddings
|
| 377 |
+
|
| 378 |
def forward(
|
| 379 |
self,
|
| 380 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
|
| 515 |
# Initialize weights and apply final processing
|
| 516 |
self.post_init()
|
| 517 |
|
| 518 |
+
def get_output_embeddings(self):
|
| 519 |
+
return self.lm_head
|
| 520 |
+
|
| 521 |
+
def set_output_embeddings(self, new_embeddings):
|
| 522 |
+
self.lm_head = new_embeddings
|
| 523 |
+
|
| 524 |
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
| 525 |
if inputs_embeds is not None and past_key_values is None:
|
| 526 |
model_inputs = {"inputs_embeds": inputs_embeds}
|