Update modeling.py
Browse files- modeling.py +2 -15
modeling.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
from transformers import PreTrainedModel
|
|
|
|
| 4 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 5 |
|
| 6 |
from .configure import RecombinationTransformerConfig
|
|
@@ -172,26 +173,12 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
|
|
| 172 |
|
| 173 |
return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
|
| 174 |
|
| 175 |
-
@classmethod
def from_pretrained(cls, *args, **kwargs):
    """Load a model previously written by ``save_pretrained``.

    Accepts the save directory either positionally or as the keyword
    ``pretrained_model_name_or_path``; an optional ``config`` keyword
    short-circuits config loading.  Restores weights from
    ``<directory>/pytorch_model.bin`` — the file ``save_pretrained``
    writes.

    Returns:
        An instance of ``cls`` with the saved state dict loaded.
    """
    # Resolve the directory regardless of how the caller passed it;
    # the original indexed kwargs unconditionally and KeyError'd on
    # positional calls.
    path = kwargs.pop('pretrained_model_name_or_path', None)
    if path is None and args:
        path = args[0]
    config = kwargs.pop('config', None)
    if config is None:
        # Fix: original referenced the undefined name
        # `CustomTransformerConfig`; this file imports
        # RecombinationTransformerConfig from .configure.
        config = RecombinationTransformerConfig.from_pretrained(path)
    model = cls(config)
    # Load the weights file save_pretrained writes, not the bare
    # directory path (original torch.load'd the path itself).
    state_dict = torch.load(f"{path}/pytorch_model.bin", map_location='cpu')
    model.load_state_dict(state_dict)
    return model
|
| 184 |
-
|
| 185 |
-
def save_pretrained(self, save_directory):
    """Persist this model's weights and configuration.

    Writes the state dict to ``<save_directory>/pytorch_model.bin``
    and delegates config serialization to ``self.config``.
    """
    weights_file = f"{save_directory}/pytorch_model.bin"
    torch.save(self.state_dict(), weights_file)
    self.config.save_pretrained(save_directory)
|
| 188 |
-
|
| 189 |
def generate(self, input_ids, max_length=20, min_length=None, num_return_sequences=1):
|
| 190 |
logits_processor = LogitsProcessorList()
|
| 191 |
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
|
| 192 |
|
| 193 |
if min_length is not None:
|
| 194 |
-
logits_processor.append(MinLengthLogitsProcessor(min_length, eos_token_id=
|
| 195 |
|
| 196 |
outputs = super().generate(
|
| 197 |
input_ids=input_ids,
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
from transformers import PreTrainedModel
|
| 4 |
+
from transformers import LogitsProcessorList, StoppingCriteriaList, MaxLengthCriteria, MinLengthLogitsProcessor
|
| 5 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 6 |
|
| 7 |
from .configure import RecombinationTransformerConfig
|
|
|
|
| 173 |
|
| 174 |
return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
|
| 175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
def generate(self, input_ids, max_length=20, min_length=None, num_return_sequences=1):
|
| 177 |
logits_processor = LogitsProcessorList()
|
| 178 |
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
|
| 179 |
|
| 180 |
if min_length is not None:
|
| 181 |
+
logits_processor.append(MinLengthLogitsProcessor(min_length, eos_token_id=config.eos_token_id))
|
| 182 |
|
| 183 |
outputs = super().generate(
|
| 184 |
input_ids=input_ids,
|