Xuezha committed on
Commit
d10f047
·
verified ·
1 Parent(s): c3c3388

Update modeling.py: import the generation utilities and remove the custom from_pretrained/save_pretrained overrides

Browse files
Files changed (1) hide show
  1. modeling.py +2 -15
modeling.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
 
4
  from transformers.modeling_outputs import CausalLMOutputWithPast
5
 
6
  from .configure import RecombinationTransformerConfig
@@ -172,26 +173,12 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
172
 
173
  return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
174
 
175
- @classmethod
176
- def from_pretrained(cls, *args, **kwargs):
177
- config = kwargs.pop('config', None)
178
- if config is None:
179
- config = CustomTransformerConfig.from_pretrained(*args, **kwargs)
180
- model = cls(config)
181
- state_dict = torch.load(kwargs['pretrained_model_name_or_path'], map_location='cpu')
182
- model.load_state_dict(state_dict)
183
- return model
184
-
185
- def save_pretrained(self, save_directory):
186
- torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
187
- self.config.save_pretrained(save_directory)
188
-
189
  def generate(self, input_ids, max_length=20, min_length=None, num_return_sequences=1):
190
  logits_processor = LogitsProcessorList()
191
  stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
192
 
193
  if min_length is not None:
194
- logits_processor.append(MinLengthLogitsProcessor(min_length, eos_token_id=self.config.eos_token_id))
195
 
196
  outputs = super().generate(
197
  input_ids=input_ids,
 
1
  import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
4
+ from transformers import LogitsProcessorList, StoppingCriteriaList, MaxLengthCriteria, MinLengthLogitsProcessor
5
  from transformers.modeling_outputs import CausalLMOutputWithPast
6
 
7
  from .configure import RecombinationTransformerConfig
 
173
 
174
  return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  def generate(self, input_ids, max_length=20, min_length=None, num_return_sequences=1):
177
  logits_processor = LogitsProcessorList()
178
  stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
179
 
180
  if min_length is not None:
181
+ logits_processor.append(MinLengthLogitsProcessor(min_length, eos_token_id=config.eos_token_id))
182
 
183
  outputs = super().generate(
184
  input_ids=input_ids,