Xuezha committed on
Commit 76e59d9 · verified · 1 Parent(s): d541062

Update modeling.py

Files changed (1)
  1. modeling.py +2 -4
modeling.py CHANGED
@@ -173,9 +173,8 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
 
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
-    def generate(self, input_ids, attention_mask=None, max_length=20, min_length=None, num_return_sequences=1):
+    def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):
         logits_processor = LogitsProcessorList()
-        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
 
         if min_length is not None:
             logits_processor.append(MinLengthLogitsProcessor(min_length, eos_token_id=self.config.eos_token_id))
@@ -185,8 +184,7 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
             attention_mask=attention_mask,
             max_length=max_length,
             num_return_sequences=num_return_sequences,
-            logits_processor=logits_processor,
-            stopping_criteria=stopping_criteria
+            logits_processor=logits_processor
         )
 
         return outputs
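
For readers without the full file: below is a minimal sketch of how the patched generate method plausibly reads after this commit. The class internals are elided, and the call target shown as super().generate (i.e. transformers' GenerationMixin.generate) is an assumption, since the diff hunks do not include the line that opens the call. The rationale for the removal is presumably that GenerationMixin.generate already builds its own max-length stopping criterion from the max_length argument, which makes the explicit StoppingCriteriaList([MaxLengthCriteria(...)]) redundant.

    # Minimal sketch of the patched method, not the full modeling.py.
    # super().generate(...) is an assumption: the diff does not show
    # which call the kwargs below are passed to.
    from transformers import (
        LogitsProcessorList,
        MinLengthLogitsProcessor,
        PreTrainedModel,
    )

    class RecombinationTransformerForCausalLM(PreTrainedModel):
        # ... config, layers, and forward() elided ...

        def generate(self, input_ids, attention_mask=None, max_length=512,
                     min_length=None, num_return_sequences=1):
            logits_processor = LogitsProcessorList()

            # The only constraint kept in this wrapper: suppress EOS until
            # at least min_length tokens have been generated.
            if min_length is not None:
                logits_processor.append(
                    MinLengthLogitsProcessor(
                        min_length, eos_token_id=self.config.eos_token_id
                    )
                )

            # No explicit StoppingCriteriaList anymore: the inner generate
            # derives a max-length stop from the max_length kwarg itself.
            outputs = super().generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                logits_processor=logits_processor
            )

            return outputs

The min_length handling stays in the wrapper, presumably because it depends on model-specific state (self.config.eos_token_id), whereas the max-length stop does not.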