Xuezha committed on
Commit
d541062
·
verified ·
1 Parent(s): 0f5b10f

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +3 -2
modeling.py CHANGED
@@ -173,15 +173,16 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
173
 
174
  return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
175
 
176
- def generate(self, input_ids, max_length=20, min_length=None, num_return_sequences=1):
177
  logits_processor = LogitsProcessorList()
178
  stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
179
 
180
  if min_length is not None:
181
- logits_processor.append(MinLengthLogitsProcessor(min_length, eos_token_id=config.eos_token_id))
182
 
183
  outputs = super().generate(
184
  input_ids=input_ids,
 
185
  max_length=max_length,
186
  num_return_sequences=num_return_sequences,
187
  logits_processor=logits_processor,
 
173
 
174
  return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
175
 
176
+ def generate(self, input_ids, attention_mask=None, max_length=20, min_length=None, num_return_sequences=1):
177
  logits_processor = LogitsProcessorList()
178
  stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
179
 
180
  if min_length is not None:
181
+ logits_processor.append(MinLengthLogitsProcessor(min_length, eos_token_id=self.config.eos_token_id))
182
 
183
  outputs = super().generate(
184
  input_ids=input_ids,
185
+ attention_mask=attention_mask,
186
  max_length=max_length,
187
  num_return_sequences=num_return_sequences,
188
  logits_processor=logits_processor,