Xuezha committed on
Commit
c3c3388
·
verified ·
1 Parent(s): b14cf69

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +31 -0
modeling.py CHANGED
@@ -171,3 +171,34 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
171
  logits = self.lm_head(x)
172
 
173
  return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  logits = self.lm_head(x)
172
 
173
  return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
174
+
175
+ @classmethod
176
+ def from_pretrained(cls, *args, **kwargs):
177
+ config = kwargs.pop('config', None)
178
+ if config is None:
179
+ config = CustomTransformerConfig.from_pretrained(*args, **kwargs)
180
+ model = cls(config)
181
+ state_dict = torch.load(kwargs['pretrained_model_name_or_path'], map_location='cpu')
182
+ model.load_state_dict(state_dict)
183
+ return model
184
+
185
+ def save_pretrained(self, save_directory):
186
+ torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
187
+ self.config.save_pretrained(save_directory)
188
+
189
+ def generate(self, input_ids, max_length=20, min_length=None, num_return_sequences=1):
190
+ logits_processor = LogitsProcessorList()
191
+ stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
192
+
193
+ if min_length is not None:
194
+ logits_processor.append(MinLengthLogitsProcessor(min_length, eos_token_id=self.config.eos_token_id))
195
+
196
+ outputs = super().generate(
197
+ input_ids=input_ids,
198
+ max_length=max_length,
199
+ num_return_sequences=num_return_sequences,
200
+ logits_processor=logits_processor,
201
+ stopping_criteria=stopping_criteria
202
+ )
203
+
204
+ return outputs