Update modeling.py
Browse files- modeling.py +7 -1
modeling.py
CHANGED
|
@@ -172,7 +172,13 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
|
|
| 172 |
logits = self.lm_head(x)
|
| 173 |
|
| 174 |
return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):
|
| 177 |
logits_processor = LogitsProcessorList()
|
| 178 |
|
|
|
|
| 172 |
logits = self.lm_head(x)
|
| 173 |
|
| 174 |
return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
|
| 175 |
+
|
| 176 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
    """Assemble the model inputs for a single decoding step.

    Args:
        input_ids: Token ids of shape (batch, seq_len) accumulated so far.
        past: Cached key/value states from previous steps, or None on the
            first step. When present (non-empty), only the newest token is
            fed forward — earlier positions are covered by the cache.
        attention_mask: Optional attention mask, passed through unchanged.
        **kwargs: Ignored; accepted for generation-loop compatibility.

    Returns:
        dict with keys "input_ids", "attention_mask", "past_key_values".
    """
    tokens = input_ids
    if past:
        # With a cache, the model only needs the most recent token;
        # keep it as a (batch, 1) column.
        tokens = tokens[:, -1].unsqueeze(-1)

    return {"input_ids": tokens, "attention_mask": attention_mask, "past_key_values": past}
|
| 181 |
+
|
| 182 |
def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):
|
| 183 |
logits_processor = LogitsProcessorList()
|
| 184 |
|