Xuezha commited on
Commit
148f88a
·
verified ·
1 Parent(s): 76e59d9

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +7 -1
modeling.py CHANGED
@@ -172,7 +172,13 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
172
  logits = self.lm_head(x)
173
 
174
  return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
175
-
 
 
 
 
 
 
176
  def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):
177
  logits_processor = LogitsProcessorList()
178
 
 
172
  logits = self.lm_head(x)
173
 
174
  return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
175
+
176
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
177
+ if past:
178
+ input_ids = input_ids[:, -1].unsqueeze(-1)
179
+
180
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
181
+
182
  def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):
183
  logits_processor = LogitsProcessorList()
184